diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c9abe95e860..e33ccc9fb50 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -50,9 +50,12 @@ jobs: steps: - name: Install environment run: | - yum -y update - yum -y install centos-release-scl-rh epel-release - yum -y install java-1.8.0-openjdk-devel devtoolset-7 rh-git218 patch python36-devel python36-numpy python36-pip python36-six + echo Not updating glibc since CUDA fails with updated versions + GLIBC="glibc glibc-common glibc-devel glibc-headers" + yum --disablerepo updates -y install $GLIBC + yum -x "$GLIBC" -y update + yum -x "$GLIBC" -y install centos-release-scl-rh epel-release + yum -x "$GLIBC" -y install java-1.8.0-openjdk-devel devtoolset-7 rh-git218 patch perl-Data-Dumper python36-devel python36-numpy python36-pip python36-six echo Downloading Maven curl -L https://archive.apache.org/dist/maven/maven-3/3.6.3/binaries/apache-maven-3.6.3-bin.tar.gz -o $HOME/apache-maven-3.6.3-bin.tar.gz tar xzf $HOME/apache-maven-3.6.3-bin.tar.gz -C /opt/ diff --git a/.gitignore b/.gitignore index 098ce71c656..cdbd28eca7c 100644 --- a/.gitignore +++ b/.gitignore @@ -53,3 +53,5 @@ gradleBuild .classpath **/target +.tf_configure.bazelrc +.clwb/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000000..d13e1c4aedb --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,110 @@ +# Building and Contributing to TensorFlow Java + +## Building + +To build all the artifacts, simply invoke the command `mvn install` at the root of this repository (or the Maven command of your choice). It is also +possible to build artifacts with support for MKL enabled with +`mvn install -Djavacpp.platform.extension=-mkl` or CUDA with `mvn install -Djavacpp.platform.extension=-gpu` +or both with `mvn install -Djavacpp.platform.extension=-mkl-gpu`. 
+ +When building this project for the first time in a given workspace, the script will attempt to download +the [TensorFlow runtime library sources](https://github.com/tensorflow/tensorflow) and build all of the native code for your platform. This requires a +valid environment for building TensorFlow, including the [bazel](https://bazel.build/) +build tool and a few Python dependencies (please read [TensorFlow documentation](https://www.tensorflow.org/install/source) +for more details). + +This step can take multiple hours on a regular laptop. It is possible though to skip completely the native build if you are working on a version that +already has pre-compiled native artifacts for your platform [available on Sonatype OSS Nexus repository](#Snapshots). You just need to activate +the `dev` profile in your Maven command to use those artifacts instead of building them from scratch +(e.g. `mvn install -Pdev`). + +Modifying the native op generation code (not the annotation processor) or the JavaCPP configuration (not the abstract Pointers) will require a +complete build to reflect the changes; otherwise `-Pdev` should be fine. + +### Native Builds + +In some cases, like when adding GPU support or re-generating op classes, you will need to re-build the native library. 99% of this is building +TensorFlow, which by default is configured for the [CI](.github/workflows/ci.yml). The build configuration can be customized using the same methods as +TensorFlow, so if you're building locally, you may need to clone the [tensorflow](https://github.com/tensorflow/tensorflow) project, run its +configuration script (`./configure`), and copy the resulting +`.tf_configure.bazelrc` to `tensorflow-core-api`. This overrides the default options, and you can add to it manually (i.e. adding `build --copt="-g"` +to build with debugging info). 
+ +### GPU Support + +Currently, due to build time constraints, the GPU binaries only support compute capabilities 3.5 and 7.0. +To use with unsupported GPUs, you have to build it yourself, after changing the value [here](tensorflow-core/tensorflow-core-api/build.sh#L27), +setting the environment variable `TF_CUDA_COMPUTE_CAPABILITIES`, or configuring it in a bazel rc file ( +i.e. `build --action_env TF_CUDA_COMPUTE_CAPABILITIES="6.1"`). While this is far from ideal, we are working on getting more build resources, and for +now this is the best option. + +To build for GPU, pass `-Djavacpp.platform.extension=-gpu` to maven. By default, the CI options are used for the bazel build, see the above section +for more info. If you add `bazelrc` files, make sure the `TF_CUDA_COMPUTE_CAPABILITIES` value in them matches the value set elsewhere, as it will take +precedence if present. + +## Running Tests + +`ndarray` can be tested using the maven `test` target. `tensorflow-core` and `tensorflow-framework`, however, should be tested using +the `integration-test` target, due to the need to include native binaries. It will **not** be run when using the `test` target of parent projects, but +will be run by `install` or `integration-test`. If you see a `no jnitensorflow in java.library.path` error from tests it is likely because you're +running the wrong test target. + +### Native Crashes + +Occasionally tests will fail with a message like: + +``` +Failed to execute goal org.apache.maven.plugins:maven-surefire-plugin:2.22.0:test(default-test)on project tensorflow-core-api:There are test failures. + + Please refer to C:\mpicbg\workspace\tensorflow\java\tensorflow-core\tensorflow-core-api\target\surefire-reports for the individual test results. + Please refer to dump files(if any exist)[date]-jvmRun[N].dump,[date].dumpstream and[date]-jvmRun[N].dumpstream. + The forked VM terminated without properly saying goodbye.VM crash or System.exit called? 
+ Command was cmd.exe/X/C"C:\Users\me\.jdks\adopt-openj9-1.8.0_275\jre\bin\java -jar C:\Users\me\AppData\Local\Temp\surefire236563113746082396\surefirebooter5751859365434514212.jar C:\Users\me\AppData\Local\Temp\surefire236563113746082396 2020-12-18T13-57-26_766-jvmRun1 surefire2445852067572510918tmp surefire_05950149004635894208tmp" + Error occurred in starting fork,check output in log + Process Exit Code:-1 + Crashed tests: + org.tensorflow.TensorFlowTest + org.apache.maven.surefire.booter.SurefireBooterForkException:The forked VM terminated without properly saying goodbye.VM crash or System.exit called? + Command was cmd.exe/X/C"C:\Users\me\.jdks\adopt-openj9-1.8.0_275\jre\bin\java -jar C:\Users\me\AppData\Local\Temp\surefire236563113746082396\surefirebooter5751859365434514212.jar C:\Users\me\AppData\Local\Temp\surefire236563113746082396 2020-12-18T13-57-26_766-jvmRun1 surefire2445852067572510918tmp surefire_05950149004635894208tmp" + Error occurred in starting fork,check output in log + Process Exit Code:-1 + Crashed tests: + org.tensorflow.TensorFlowTest + at org.apache.maven.plugin.surefire.booterclient.ForkStarter.fork(ForkStarter.java:671) + at org.apache.maven.plugin.surefire.booterclient.ForkStarter.fork(ForkStarter.java:533) + at org.apache.maven.plugin.surefire.booterclient.ForkStarter.run(ForkStarter.java:278) + at org.apache.maven.plugin.surefire.booterclient.ForkStarter.run(ForkStarter.java:244) +``` + +This is because the native code crashed (i.e. because of a segfault), and it should have created a dump file somewhere in the project that you can use +to tell what caused the issue. + +## Contributing + +### Formatting + +Java sources should be formatted according to the [Google style guide](https://google.github.io/styleguide/javaguide.html). 
It can be included +in [IntelliJ](https://github.com/google/styleguide/blob/gh-pages/intellij-java-google-style.xml) and +[Eclipse](https://github.com/google/styleguide/blob/gh-pages/eclipse-java-google-style.xml). +[Google's C++ style guide](https://google.github.io/styleguide/cppguide.html) should also be used for C++ code. + +### Code generation + +Code generation for `Ops` and related classes is done during `tensorflow-core-api`'s `compile` phase, using the annotation processor in +`tensorflow-core-generator`. If you change or add any operator classes (annotated with `org.tensorflow.op.annotation.Operator`), endpoint methods ( +annotated with `org.tensorflow.op.annotation.Endpoint`), or change the annotation processor, be sure to re-run a +`mvn install` in `tensorflow-core-api` (`-Pdev` is fine for this, it just needs to run the annotation processor). + +### Working with Bazel generation + +`tensorflow-core-api` uses Bazel-built C++ code generation to generate most of the `@Operator` classes. See [Native Builds](#native-builds) for +instructions on configuring the bazel build. To run the code generation, use the `//:java_op_generator` target. The resulting binary has good help +text (viewable in +[op_gen_main.cc](tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_gen_main.cc#L31-L48)). Generally, it should be called with arguments +that are something like: + +``` +bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/libtensorflow_cc.so --output_dir=src/gen/java --api_dirs=bazel-tensorflow-core-api/external/org_tensorflow/tensorflow/core/api_def/base_api,src/bazel/api_def +``` + +(called in `tensorflow-core-api`). diff --git a/LICENSE b/LICENSE index 786bd07395c..261eeb9e9f8 100644 --- a/LICENSE +++ b/LICENSE @@ -1,5 +1,3 @@ -Copyright 2020 The TensorFlow Authors. All rights reserved. 
- Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ diff --git a/README.md b/README.md index bca18d1fe49..ec97ea648e6 100644 --- a/README.md +++ b/README.md @@ -34,26 +34,17 @@ The following describes the layout of the repository and its different artifacts * Intended audience: any developer who needs a Java n-dimensional array implementation, whether or not they use it with TensorFlow -## Building Sources -To build all the artifacts, simply invoke the command `mvn install` at the root of this repository (or -the Maven command of your choice). It is also possible to build artifacts with support for MKL enabled with -`mvn install -Djavacpp.platform.extension=-mkl` or CUDA with `mvn install -Djavacpp.platform.extension=-gpu` -or both with `mvn install -Djavacpp.platform.extension=-mkl-gpu`. +## Communication -When building this project for the first time in a given workspace, the script will attempt to download -the [TensorFlow runtime library sources](https://github.com/tensorflow/tensorflow) and build of all the native code -for your platform. This requires a valid environment for building TensorFlow, including the [bazel](https://bazel.build/) -build tool and a few Python dependencies (please read [TensorFlow documentation](https://www.tensorflow.org/install/source) -for more details). +This repository is maintained by TensorFlow JVM Special Interest Group (SIG). You can easily join the group +by subscribing to the [jvm@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/jvm) +mailing list, or you can simply send pull requests and raise issues to this repository. +There is also a [sig-jvm Gitter channel](https://gitter.im/tensorflow/sig-jvm). -This step can take multiple hours on a regular laptop. It is possible though to skip completely the native build if you are -working on a version that already has pre-compiled native artifacts for your platform [available on Sonatype OSS Nexus repository](#Snapshots). 
-You just need to activate the `dev` profile in your Maven command to use those artifacts instead of building them from scratch -(e.g. `mvn install -Pdev`). +## Building Sources -Note that modifying any source files under `tensorflow-core` may impact the low-level TensorFlow bindings, in which case a -complete build could be required to reflect the changes. +See [CONTRIBUTING.md](CONTRIBUTING.md#building). ## Using Maven Artifacts @@ -162,6 +153,4 @@ This table shows the mapping between different version of TensorFlow for Java an ## How to Contribute? -This repository is maintained by TensorFlow JVM Special Interest Group (SIG). You can easily join the group -by subscribing to the [jvm@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/jvm) -mailing list, or you can simply send pull requests and raise issues to this repository. +Contributions are welcome, guidelines are located in [CONTRIBUTING.md](CONTRIBUTING.md). diff --git a/ndarray/pom.xml b/ndarray/pom.xml index d228fbdb32a..4139b8b7929 100644 --- a/ndarray/pom.xml +++ b/ndarray/pom.xml @@ -80,7 +80,6 @@ 1 false -Xmx2G -XX:MaxPermSize=256m - false **/*Test.java diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java index e4bdc53c713..7d0f0222bbe 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java @@ -18,8 +18,8 @@ package org.tensorflow.ndarray.impl.dimension; import java.util.Arrays; -import org.tensorflow.ndarray.index.Index; import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.index.Index; public class DimensionalSpace { @@ -35,24 +35,42 @@ public static DimensionalSpace create(Shape shape) { } public RelativeDimensionalSpace mapTo(Index[] indices) { - if (dimensions == null || indices.length > dimensions.length) { + if 
(dimensions == null) { throw new ArrayIndexOutOfBoundsException(); } int dimIdx = 0; + int indexIdx = 0; int newDimIdx = 0; int segmentationIdx = -1; long initialOffset = 0; - Dimension[] newDimensions = new Dimension[dimensions.length]; - while (dimIdx < indices.length) { + int newAxes = 0; + boolean seenEllipsis = false; + for (Index idx : indices) { + if (idx.isNewAxis()) { + newAxes += 1; + } + if (idx.isEllipsis()) { + if (seenEllipsis) { + throw new IllegalArgumentException("Only one ellipsis allowed"); + } else { + seenEllipsis = true; + } + } + } + int newLength = dimensions.length + newAxes; + + Dimension[] newDimensions = new Dimension[newLength]; + while (indexIdx < indices.length) { - if (indices[dimIdx].isPoint()) { + if (indices[indexIdx].isPoint()) { // When an index targets a single point in a given dimension, calculate the offset of this // point and cumulate the offset of any subsequent point as well long offset = 0; do { - offset += indices[dimIdx].mapCoordinate(0, dimensions[dimIdx]); - } while (++dimIdx < indices.length && indices[dimIdx].isPoint()); + offset += indices[indexIdx].mapCoordinate(0, dimensions[dimIdx]); + dimIdx++; + } while (++indexIdx < indices.length && indices[indexIdx].isPoint()); // If this is the first index, then the offset is the position of the whole dimension // space within the original one. If not, then we apply the offset to the last vectorial @@ -65,14 +83,47 @@ public RelativeDimensionalSpace mapTo(Index[] indices) { segmentationIdx = newDimIdx - 1; } + } else if (indices[indexIdx].isNewAxis()) { + long newSize; + if (dimIdx == 0) { + // includes everything. Should really include future reduction (at()) but that doesn't seem to cause issues + // elsewhere + newSize = dimensions[0].numElements() * dimensions[0].elementSize(); + } else { + newSize = dimensions[dimIdx - 1].elementSize(); + } + + newDimensions[newDimIdx] = new Axis(1, newSize); + segmentationIdx = newDimIdx; // is this correct? 
+ ++newDimIdx; + ++indexIdx; + } else if (indices[indexIdx].isEllipsis()) { + int remainingDimensions = dimensions.length - dimIdx; + int requiredDimensions = 0; + for (int i = indexIdx + 1; i < indices.length; i++) { + if (!indices[i].isNewAxis()) { + requiredDimensions++; + } + } + // while the number of dimensions left < the number of indices that consume axes + while (remainingDimensions > requiredDimensions) { + Dimension dim = dimensions[dimIdx++]; + if (dim.isSegmented()) { + segmentationIdx = newDimIdx; + } + newDimensions[newDimIdx++] = dim; + remainingDimensions--; + } + indexIdx++; } else { // Map any other index to the appropriate dimension of this space - Dimension newDimension = indices[dimIdx].apply(dimensions[dimIdx++]); + Dimension newDimension = indices[indexIdx].apply(dimensions[dimIdx++]); newDimensions[newDimIdx] = newDimension; if (newDimension.isSegmented()) { segmentationIdx = newDimIdx; } ++newDimIdx; + ++indexIdx; } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java index b38e33d5e22..9d3139f3248 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java @@ -39,4 +39,19 @@ public Dimension apply(Dimension dim) { private All() { } + + @Override + public boolean beginMask() { + return true; + } + + @Override + public boolean endMask() { + return true; + } + + @Override + public String toString() { + return All.class.getSimpleName() + "()"; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java index 5d92ee3286b..31ce021ddc8 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java @@ -16,6 +16,7 @@ */ package org.tensorflow.ndarray.index; +import java.util.StringJoiner; import 
org.tensorflow.ndarray.impl.dimension.Dimension; final class At implements Index { @@ -27,22 +28,47 @@ public long numElements(Dimension dim) { @Override public long mapCoordinate(long coordinate, Dimension dim) { - return dim.positionOf(coord); // TODO validate coordinate is 0? + long coord = this.coord >= 0 ? this.coord : dim.numElements() + this.coord; + return dim.positionOf(coord); } @Override public Dimension apply(Dimension dim) { - throw new IllegalStateException(); // FIXME? + if (!keepDim) { + throw new UnsupportedOperationException("Should be handled in DimensionalSpace."); + } + + return dim.withIndex(this); } @Override public boolean isPoint() { - return true; + return !keepDim; } - At(long coord) { + At(long coord, boolean keepDim) { this.coord = coord; + this.keepDim = keepDim; } private final long coord; + private final boolean keepDim; + + @Override + public long begin() { + return coord; + } + + @Override + public long end() { + return coord + 1; + } + + @Override + public String toString() { + return new StringJoiner(", ", At.class.getSimpleName() + "(", ")") + .add("coord=" + coord) + .add("keepDim=" + keepDim) + .toString(); + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Even.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java similarity index 61% rename from ndarray/src/main/java/org/tensorflow/ndarray/index/Even.java rename to ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java index 54f53853c32..d4085735df2 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Even.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java @@ -1,5 +1,5 @@ /* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -12,26 +12,37 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - ======================================================================= + ============================================================================== */ package org.tensorflow.ndarray.index; import org.tensorflow.ndarray.impl.dimension.Dimension; -final class Even implements Index { +final class Ellipsis implements Index { - static final Even INSTANCE = new Even(); + static final Ellipsis INSTANCE = new Ellipsis(); + + private Ellipsis() { + + } @Override public long numElements(Dimension dim) { - return (dim.numElements() >> 1) + (dim.numElements() % 2); + throw new UnsupportedOperationException("Should be handled in DimensionalSpace."); } @Override public long mapCoordinate(long coordinate, Dimension dim) { - return coordinate << 1; + throw new UnsupportedOperationException("Should be handled in DimensionalSpace."); } - private Even() { + @Override + public boolean isEllipsis() { + return true; + } + + @Override + public String toString() { + return Ellipsis.class.getSimpleName() + "()"; } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Flip.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Flip.java deleted file mode 100644 index 7914d8faad5..00000000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Flip.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ -package org.tensorflow.ndarray.index; - -import org.tensorflow.ndarray.impl.dimension.Dimension; - -final class Flip implements Index { - - static final Flip INSTANCE = new Flip(); - - @Override - public long numElements(Dimension dim) { - return dim.numElements(); - } - - @Override - public long mapCoordinate(long coordinate, Dimension dim) { - return dim.numElements() - coordinate - 1; - } -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/From.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/From.java deleted file mode 100644 index c541e8370b2..00000000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/From.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ======================================================================= - */ -package org.tensorflow.ndarray.index; - -import org.tensorflow.ndarray.impl.dimension.Dimension; - -final class From implements Index { - - @Override - public long numElements(Dimension dim) { - return dim.numElements() - start; - } - - @Override - public long mapCoordinate(long coordinate, Dimension dim) { - return start + coordinate; - } - - From(long start) { - this.start = start; - } - - private final long start; -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java index 00b411d0167..55c4e510748 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java @@ -15,6 +15,7 @@ */ package org.tensorflow.ndarray.index; +import java.util.StringJoiner; import org.tensorflow.ndarray.impl.dimension.Dimension; /** @@ -71,4 +72,19 @@ public boolean isPoint() { private final long stride; private final long count; private final long block; + + @Override + public String toString() { + return new StringJoiner(", ", Hyperslab.class.getSimpleName() + "(", ")") + .add("start=" + start) + .add("stride=" + stride) + .add("count=" + count) + .add("block=" + block) + .toString(); + } + + @Override + public boolean isStridedSlicingCompliant() { + return false; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java index da6aa9049f6..617ca4d474b 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java @@ -23,19 +23,16 @@ * An index used for slicing a view out of an N-dimensional array. * *

A slice, i.e. a reduced view, of an N-dimensional array is obtain by calling - * {@link NdArray#slice(Index...)}, given a list of indices - * that select which elements on a given dimension should be included/excluded - * from that view. + * {@link NdArray#slice(Index...)}, given a list of indices that select which elements on a given dimension should be + * included/excluded from that view. */ public interface Index { /** - * Returns the number of elements that can be retrieved using this index on the - * given dimension. + * Returns the number of elements that can be retrieved using this index on the given dimension. * *

An index that maps one-by-one all elements of the dimensions will return a value - * equal to {@code dim.numElements()}, while an index that only maps a subset of these - * will return a smaller value. + * equal to {@code dim.numElements()}, while an index that only maps a subset of these will return a smaller value. * * @param dim the indexed dimension * @return number of elements accessible @@ -43,8 +40,7 @@ public interface Index { long numElements(Dimension dim); /** - * Transforms an element coordinate to a new coordinate by applying this index to the - * given dimension. + * Transforms an element coordinate to a new coordinate by applying this index to the given dimension. * *

For example, if the coordinate is 0 and this index flips the {@code n} elements on this * dimension, then the returned value will be {@code n-1}. @@ -74,4 +70,62 @@ default Dimension apply(Dimension dim) { default boolean isPoint() { return false; } + + /** + * Returns true if this index is a new axis, adding a dimension of size 1 + */ + default boolean isNewAxis() { + return false; + } + + /** + * Returns true if this index is an ellipsis, expanding to take as many dimensions as possible (and applying all() to + * them) + */ + default boolean isEllipsis() { + return false; + } + + /** + * Get whether the Index supports strided slice style indexing (using start, end, stride, and flags, i.e. TensorFlow's). + */ + default boolean isStridedSlicingCompliant() { + return true; + } + + /** + * Get the start of the index, for strided slice style indexing. + */ + default long begin() { + return 0; + } + + /** + * Get the end of the index, strided slice style indexing. + */ + default long end() { + return 0; + } + + /** + * Get the stride of the index, for strided slice style indexing. + */ + default long stride() { + return 1; + } + + /** + * Get whether the Index should start at the beginning of the dimension, for strided slice style indexing. + */ + default boolean beginMask() { + return false; + } + + /** + * Get whether the Index should end at the beginning of the dimension, for strided slice style indexing. + */ + default boolean endMask() { + return false; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java index abc72195c82..346ab705595 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java @@ -34,14 +34,14 @@ public final class Indices { * single element and therefore is excluded from the computation of the rank. * *

For example, given a 3D matrix on the axis [x, y, z], if - * {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its - * number of elements is {@code x.numElements()} + * {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its number of elements is + * {@code x.numElements()} * * @param coord coordinate of the element on the indexed axis * @return index */ public static Index at(long coord) { - return new At(coord); + return new At(coord, false); } /** @@ -58,7 +58,46 @@ public static Index at(NdArray coord) { if (coord.rank() > 0) { throw new IllegalRankException("Only scalars are accepted as a value index"); } - return new At(coord.getObject().longValue()); + return new At(coord.getObject().longValue(), false); + } + + /** + * A coordinate that selects a specific element on a given dimension. + * + *

When this index is applied to a given dimension, the dimension is resolved as a + * single element and therefore, if {@code keepDim} is false, is excluded from the computation of the rank. If + * {@code keepDim} is true, the dimension is collapsed down to one element. + * + *

For example, given a 3D matrix on the axis [x, y, z], if + * {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its number of elements is + * {@code x.numElements()} + * + * @param coord coordinate of the element on the indexed axis + * @param keepDim whether to keep the dimension in the result (with size 1) instead of removing it. + * @return index + */ + public static Index at(long coord, boolean keepDim) { + return new At(coord, keepDim); + } + + /** + * A coordinate that selects a specific element on a given dimension. + * + *

This is equivalent to call {@link #at(long, boolean)} but where the value of the coordinate is + * provided by an N-dimensional array. + *

+ * If {@code keepDim} is true, the dimension is collapsed down to one element instead of being removed. + * + * @param coord scalar indicating the coordinate of the element on the indexed axis + * @param keepDim whether to keep the dimension in the result (with size 1) instead of removing it. + * @return index + * @throws IllegalRankException if {@code coord} is not a scalar (rank 0) + */ + public static Index at(NdArray coord, boolean keepDim) { + if (coord.rank() > 0) { + throw new IllegalRankException("Only scalars are accepted as a value index"); + } + return new At(coord.getObject().longValue(), keepDim); + } + + /** @@ -110,8 +149,7 @@ public static Index seq(NdArray coords) { } /** - * An index that returns only elements found at an even position in the - * original dimension. + * An index that returns only elements found at an even position in the original dimension. + * + *

For example, given a vector with {@code n} elements on the {@code x} axis, and n is even, * {@code even()} returns x0, x2, ..., xn-2 @@ -119,12 +157,11 @@ public static Index seq(NdArray coords) { * @return index */ public static Index even() { - return Even.INSTANCE; + return step(2); } /** - * An index that returns only elements found at an odd position in the - * original dimension. + * An index that returns only elements found at an odd position in the original dimension. * *

For example, given a vector with {@code n} elements on the {@code x} axis, and n is even, * {@code odd()} returns x1, x3, ..., xn-1 @@ -132,7 +169,7 @@ public static Index even() { * @return index */ public static Index odd() { - return Odd.INSTANCE; + return sliceFrom(1, 2); } /** @@ -141,30 +178,44 @@ public static Index odd() { *

For example, given a vector with {@code n} elements on the {@code x} axis, * {@code step(k)} returns x0, xk, xk*2, ... * - * @param stepLength the number of elements between each steps + * @param stride the number of elements between each steps + * @return index + */ + public static Index step(long stride) { + return new Step(stride); + } + + /** + * An index that returns only elements on a given dimension starting at a specific coordinate. + * + *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k}, + * {@code from(k)} returns xk, xk+1, ..., xn-1 + * + * @param start coordinate of the first element of the sequence * @return index */ - public static Index step(long stepLength) { - return new Step(stepLength); + public static Index sliceFrom(long start) { + return sliceFrom(start, 1); } /** - * An index that returns only elements on a given dimension starting at a - * specific coordinate. + * An index that returns only elements on a given dimension starting at a specific coordinate, using the given + * stride. * *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k}, * {@code from(k)} returns xk, xk+1, ..., xn-1 * * @param start coordinate of the first element of the sequence + * @param stride the stride to use * @return index + * @see #slice(long, long, long) */ - public static Index from(long start) { - return new From(start); + public static Index sliceFrom(long start, long stride) { + return new SliceFrom(start, stride); } /** - * An index that returns only elements on a given dimension up to a - * specific coordinate. + * An index that returns only elements on a given dimension up to a specific coordinate. * *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k}, * {@code to(k)} returns x0, x1, ..., xk @@ -172,8 +223,23 @@ public static Index from(long start) { * @param end coordinate of the last element of the sequence (exclusive) * @return index */ - public static Index to(long end) { - return new To(end); + public static Index sliceTo(long end) { + return sliceTo(end, 1); + } + + /** + * An index that returns only elements on a given dimension up to a specific coordinate, using the given stride. + * + *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k}, + * {@code to(k)} returns x0, x1, ..., xk + * + * @param end coordinate of the last element of the sequence (exclusive) + * @param stride the stride to use + * @return index + * @see #slice(long, long, long) + */ + public static Index sliceTo(long end, long stride) { + return new SliceTo(end, stride); } /** @@ -187,7 +253,7 @@ public static Index to(long end) { * @return index */ public static Index range(long start, long end) { - return new Range(start, end); + return slice(start, end); } /** @@ -199,21 +265,99 @@ public static Index range(long start, long end) { * @return index */ public static Index flip() { - return Flip.INSTANCE; + return slice(null, null, -1); } - + /** - * An index that returns elements according to an hyperslab defined by {@code start}, - * {@code stride}, {@code count}, {@code block}. See {@link Hyperslab}. - * + * An index that returns elements according to an hyperslab defined by {@code start}, {@code stride}, {@code count}, + * {@code block}. See {@link Hyperslab}. + * * @param start Starting location for the hyperslab. * @param stride The number of elements to separate each element or block to be selected. * @param count The number of elements or blocks to select along the dimension. * @param block The size of the block selected from the dimension. - * * @return index */ public static Index hyperslab(long start, long stride, long count, long block) { return new Hyperslab(start, stride, count, block); } + + /** + * An index that inserts a new dimension of size 1 into the resulting array. + * + * @return index + */ + public static Index newAxis() { + return NewAxis.INSTANCE; + } + + /** + * An index that expands to fill all available source dimensions. Works the same as Python's {@code ...}. 
+ * + * @return index + */ + public static Index ellipsis() { + return Ellipsis.INSTANCE; + } + + /** + * An index that returns elements between {@code start} and {@code end}. Negative values of {@code start} or + * {@code end} are interpreted as relative to the end of the dimension. + *<p>

+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static Index slice(long start, long end) { + return slice(start, end, 1); + } + + /** + * An index that returns every {@code stride}-th element between {@code start} and {@code end}. Negative values of + * {@code start} or {@code end} are interpreted as relative to the end of the dimension. + *<p>

+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static Index slice(long start, long end, long stride) { + return new Slice(start, end, stride); + } + + /** + * An index that returns elements between {@code start} and {@code end}. If {@code start} or {@code end} is {@code + * null}, starts or ends at the beginning or the end, respectively. + *

+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static Index slice(Long start, Long end) { + return slice(start, end, 1); + } + + /** + * An index that returns every {@code stride}-th element between {@code start} and {@code end}. If {@code start} or + * {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + *

+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static Index slice(Long start, Long end, long stride) { + if (start == null && end == null) { + if (stride == 1) { + return Indices.all(); + } else { + return Indices.step(stride); + } + } else if (start == null) { + return Indices.sliceTo(end, stride); + } else if (end == null) { + return Indices.sliceFrom(start, stride); + } + + return slice(start.longValue(), end.longValue(), stride); + } + } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/To.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java similarity index 65% rename from ndarray/src/main/java/org/tensorflow/ndarray/index/To.java rename to ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java index 167d1c6865e..a68b1ed9ad1 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/To.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java @@ -1,5 +1,5 @@ /* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,17 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
- ======================================================================= + ============================================================================== */ package org.tensorflow.ndarray.index; import org.tensorflow.ndarray.impl.dimension.Dimension; -final class To implements Index { +final class NewAxis implements Index { + + static final NewAxis INSTANCE = new NewAxis(); + + private NewAxis() { + + } @Override public long numElements(Dimension dim) { - return end; + return 1; } @Override @@ -30,9 +36,18 @@ public long mapCoordinate(long coordinate, Dimension dim) { return coordinate; } - To(long end) { - this.end = end; + @Override + public Dimension apply(Dimension dim) { + throw new IllegalStateException(); + } + + @Override + public boolean isNewAxis() { + return true; } - private final long end; + @Override + public String toString() { + return NewAxis.class.getSimpleName() + "()"; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Odd.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Odd.java deleted file mode 100644 index 070331f1ffb..00000000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Odd.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ======================================================================= - */ -package org.tensorflow.ndarray.index; - -import org.tensorflow.ndarray.impl.dimension.Dimension; - -final class Odd implements Index { - - static final Odd INSTANCE = new Odd(); - - @Override - public long numElements(Dimension dim) { - return dim.numElements() >> 1; - } - - @Override - public long mapCoordinate(long coordinate, Dimension dim) { - return (coordinate << 1) + 1; - } - - private Odd() { - } -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Range.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Range.java deleted file mode 100644 index e5d6003d87b..00000000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Range.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ======================================================================= - */ -package org.tensorflow.ndarray.index; - -import org.tensorflow.ndarray.impl.dimension.Dimension; - -final class Range implements Index { - - @Override - public long numElements(Dimension dim) { - return end - start; - } - - @Override - public long mapCoordinate(long coordinate, Dimension dim) { - return start + coordinate; - } - - Range(long start, long end) { - this.start = start; - this.end = end; - } - - private final long start; - private final long end; -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java index 41d37d05806..5b93e434e54 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java @@ -16,6 +16,7 @@ */ package org.tensorflow.ndarray.index; +import java.util.StringJoiner; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.impl.dimension.Dimension; @@ -36,4 +37,16 @@ public long mapCoordinate(long coordinate, Dimension dim) { } private final NdArray coords; + + @Override + public String toString() { + return new StringJoiner(", ", Sequence.class.getSimpleName() + "(", ")") + .add("coords=" + coords) + .toString(); + } + + @Override + public boolean isStridedSlicingCompliant() { + return false; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java new file mode 100644 index 00000000000..1be4368261c --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java @@ -0,0 +1,89 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.ndarray.index; + +import java.util.StringJoiner; +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class Slice implements Index { + + Slice(long start, long end, long stride) { + this.start = start; + this.end = end; + this.stride = stride; + + if (stride == 0) { + throw new IllegalArgumentException("Can not have a stride of 0"); + } + } + + @Override + public long numElements(Dimension dim) { + long length = end(dim) - start(dim); + + return (length / stride) + (length % stride != 0 ? 
1 : 0); + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + return start(dim) + stride * coordinate; + } + + @Override + public long begin() { + return start; + } + + @Override + public long end() { + return end; + } + + @Override + public long stride() { + return stride; + } + + @Override + public String toString() { + return new StringJoiner(", ", Slice.class.getSimpleName() + "(", ")") + .add("start=" + start) + .add("end=" + end) + .add("stride=" + stride) + .toString(); + } + + private long start(Dimension dim) { + if (start < 0) { + return dim.numElements() + start; + } + + return start; + } + + private long end(Dimension dim) { + if (end < 0) { + return dim.numElements() + end; + } else { + return end; + } + } + + private final long start; + private final long end; + private final long stride; +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceFrom.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceFrom.java new file mode 100644 index 00000000000..c968a325cf7 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceFrom.java @@ -0,0 +1,86 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================================== + */ +package org.tensorflow.ndarray.index; + +import java.util.StringJoiner; +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class SliceFrom implements Index { + + SliceFrom(long start, long stride) { + this.start = start; + this.stride = stride; + + if (stride == 0) { + throw new IllegalArgumentException("Can not have a stride of 0"); + } + } + + @Override + public long numElements(Dimension dim) { + long length = end(dim) - start(dim); + + return (length / stride) + (length % stride != 0 ? 1 : 0); + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + return start(dim) + stride * coordinate; + } + + @Override + public long begin() { + return start; + } + + @Override + public boolean endMask() { + return true; + } + + @Override + public long stride() { + return stride; + } + + @Override + public String toString() { + return new StringJoiner(", ", SliceFrom.class.getSimpleName() + "(", ")") + .add("start=" + start) + .add("stride=" + stride) + .toString(); + } + + private long start(Dimension dim) { + if (start < 0) { + return dim.numElements() + start; + } + + return start; + } + + private long end(Dimension dim) { + if (stride > 0) { + return dim.numElements(); + } else { + return -1; // it's exclusive + } + } + + private final long start; + private final long stride; +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceTo.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceTo.java new file mode 100644 index 00000000000..761d1d52a3a --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceTo.java @@ -0,0 +1,86 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.ndarray.index; + +import java.util.StringJoiner; +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class SliceTo implements Index { + + SliceTo(long end, long stride) { + this.end = end; + this.stride = stride; + + if (stride == 0) { + throw new IllegalArgumentException("Can not have a stride of 0"); + } + } + + @Override + public long numElements(Dimension dim) { + long length = end(dim) - start(dim); + + return (length / stride) + (length % stride != 0 ? 1 : 0); + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + return start(dim) + stride * coordinate; + } + + @Override + public long end() { + return end; + } + + @Override + public boolean beginMask() { + return true; + } + + @Override + public long stride() { + return stride; + } + + @Override + public String toString() { + return new StringJoiner(", ", SliceTo.class.getSimpleName() + "(", ")") + .add("end=" + end) + .add("stride=" + stride) + .toString(); + } + + private long start(Dimension dim) { + if (stride > 0) { + return 0; + } + + return dim.numElements() - 1; // it's inclusive + } + + private long end(Dimension dim) { + if (end < 0) { + return dim.numElements() + end; + } else { + return end; + } + } + + private final long end; + private final long stride; +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java index 725abd8f2e7..c9a21c507b6 100644 --- 
a/ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java @@ -1,5 +1,5 @@ /* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,27 +12,72 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - ======================================================================= + ============================================================================== */ package org.tensorflow.ndarray.index; +import java.util.StringJoiner; import org.tensorflow.ndarray.impl.dimension.Dimension; final class Step implements Index { + Step(long stride) { + this.stride = stride; + + if (stride == 0) { + throw new IllegalArgumentException("Can not have a stride of 0"); + } + } + @Override public long numElements(Dimension dim) { - return (dim.numElements() / stepLength) + 1; // FIXME always include element 0? + long length = end(dim) - start(dim); + + return (length / stride) + (length % stride != 0 ? 
1 : 0); } @Override public long mapCoordinate(long coordinate, Dimension dim) { - return coordinate * stepLength; + return start(dim) + stride * coordinate; + } + + @Override + public boolean beginMask() { + return true; + } + + @Override + public boolean endMask() { + return true; + } + + @Override + public long stride() { + return stride; + } + + @Override + public String toString() { + return new StringJoiner(", ", Step.class.getSimpleName() + "(", ")") + .add("stride=" + stride) + .toString(); + } + + private long start(Dimension dim) { + if (stride > 0) { + return 0; + } + + return dim.numElements() - 1; // it's inclusive } - Step(long stepLength) { - this.stepLength = stepLength; + private long end(Dimension dim) { + if (stride > 0) { + return dim.numElements(); + } else { + return -1; // it's exclusive + } } - private final long stepLength; + private final long stride; } diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java new file mode 100644 index 00000000000..6f92dab9b99 --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java @@ -0,0 +1,205 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================================== + */ +package org.tensorflow.ndarray; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.index.Indices; + +public class IndexTest { + @Test + public void testNullConversions(){ + assertTrue(Indices.slice(null, 0L).beginMask(), + "Passed null for slice start but didn't set begin mask"); + + assertTrue(Indices.slice(null, 0L).beginMask(), + "Passed null for slice start but didn't set begin mask"); + + assertTrue(Indices.slice(null, null).beginMask(), + "Passed null for slice start but didn't set begin mask"); + + assertTrue(Indices.slice(0L, null).endMask(), + "Passed null for slice end but didn't set end mask"); + + assertTrue(Indices.slice(0L, null).endMask(), + "Passed null for slice end but didn't set end mask"); + + assertTrue(Indices.slice(null, null).endMask(), + "Passed null for slice end but didn't set end mask"); + } + + @Test + public void testNewaxis(){ + IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> + scalar.setInt((int)coords[2]) + ); + + IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.all(), Indices.all(), Indices.newAxis()); + + assertEquals(Shape.of(5, 4, 5, 1), slice1.shape()); + assertEquals(0, slice1.getInt(0, 0, 0, 0)); + assertEquals(1, slice1.getInt(0, 0, 1, 0)); + assertEquals(4, slice1.getInt(0, 0, 4, 0)); + assertEquals(2, slice1.getInt(0, 1, 2, 0)); + + IntNdArray slice2 = matrix3d.slice(Indices.all(), Indices.all(), Indices.newAxis(), Indices.all()); + + assertEquals(Shape.of(5, 4, 1, 5), slice2.shape()); + assertEquals(0, slice2.getInt(0, 0, 0, 0)); + assertEquals(1, slice2.getInt(0, 0, 0, 1)); + assertEquals(4, slice2.getInt(0, 0, 0, 4)); + assertEquals(2, slice2.getInt(0, 1, 0, 2)); + + IntNdArray slice3 = matrix3d.slice(Indices.all(), 
Indices.newAxis(), Indices.all(), Indices.all()); + + assertEquals(Shape.of(5, 1, 4, 5), slice3.shape()); + assertEquals(0, slice3.getInt(0, 0, 0, 0)); + assertEquals(1, slice3.getInt(0, 0, 0, 1)); + assertEquals(4, slice3.getInt(0, 0, 0, 4)); + assertEquals(2, slice3.getInt(0, 0, 1, 2)); + + IntNdArray slice4 = matrix3d.slice(Indices.newAxis(), Indices.all(), Indices.all(), Indices.all()); + + assertEquals(Shape.of(1, 5, 4, 5), slice4.shape()); + assertEquals(0, slice4.getInt(0, 0, 0, 0)); + assertEquals(1, slice4.getInt(0, 0, 0, 1)); + assertEquals(4, slice4.getInt(0, 0, 0, 4)); + assertEquals(2, slice4.getInt(0, 0, 1, 2)); + + } + + @Test + public void testEllipsis(){ + IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> + scalar.setInt((int)coords[2]) + ); + + assertEquals( + matrix3d.slice(Indices.all(), Indices.all(), Indices.at(0)), + matrix3d.slice(Indices.ellipsis(), Indices.at(0)) + ); + + assertEquals( + matrix3d.slice(Indices.at(0), Indices.all(), Indices.all()), + matrix3d.slice(Indices.at(0), Indices.ellipsis()) + ); + + assertEquals( + matrix3d.slice(Indices.at(0), Indices.all(), Indices.at(0)), + matrix3d.slice(Indices.at(0), Indices.ellipsis(), Indices.at(0)) + ); + + // newaxis interacts specially with ellipsis (since it doesn't consume a dimension), test this + + assertEquals( + matrix3d.slice(Indices.all(), Indices.all(), Indices.newAxis(), Indices.at(0)), + matrix3d.slice(Indices.ellipsis(), Indices.newAxis(), Indices.at(0)) + ); + + assertEquals( + matrix3d.slice(Indices.newAxis(), Indices.all(), Indices.all(), Indices.at(0)), + matrix3d.slice(Indices.newAxis(), Indices.ellipsis(), Indices.at(0)) + ); + + assertEquals( + matrix3d.slice(Indices.all(), Indices.all(), Indices.at(0), Indices.newAxis()), + matrix3d.slice(Indices.ellipsis(), Indices.at(0), Indices.newAxis()) + ); + } + + @Test + public void testSlice(){ + IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5)); + + 
matrix3d.scalars().forEachIndexed((coords, scalar) -> + scalar.setInt((int)coords[2]) + ); + + IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.sliceTo(3), Indices.all()); + + assertEquals(Shape.of(5, 3, 5), slice1.shape()); + assertEquals(0, slice1.getInt(0, 0, 0)); + assertEquals(1, slice1.getInt(0, 0, 1)); + assertEquals(2, slice1.getInt(0, 1, 2)); + + IntNdArray slice2 = matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(1, 4)); + + assertEquals(Shape.of(5, 4, 3), slice2.shape()); + assertEquals(1, slice2.getInt(0, 0, 0)); + assertEquals(3, slice2.getInt(0, 0, 2)); + assertEquals(2, slice2.getInt(0, 1, 1)); + + assertEquals(slice2, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(1, -1))); + + assertEquals(slice2, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(-4, -1))); + + assertEquals(Shape.of(5, 4, 0), matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(1, 4, -2)).shape()); + + IntNdArray slice3 = matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(4, 1, -2)); + + assertEquals(Shape.of(5, 4, 2), slice3.shape()); + assertEquals(4, slice3.getInt(0, 0, 0)); + assertEquals(2, slice3.getInt(0, 1, 1)); + + assertEquals(slice3, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(-1, 1, -2))); + + assertEquals(slice3, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(-1, -4, -2))); + + IntNdArray slice4 = matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(null, null, -1)); + + assertEquals(Shape.of(5, 4, 5), slice4.shape()); + assertEquals(4, slice4.getInt(0, 0, 0)); + assertEquals(3, slice4.getInt(0, 0, 1)); + assertEquals(2, slice4.getInt(0, 1, 2)); + } + + @Test + public void testAt(){ + IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> + scalar.setInt((int)coords[2]) + ); + + IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(0)); + + assertEquals(Shape.of(5, 4), slice1.shape()); + 
assertEquals(0, slice1.getInt(0, 0)); + + IntNdArray slice2 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(3)); + + assertEquals(Shape.of(5, 4), slice2.shape()); + assertEquals(3, slice2.getInt(0, 0)); + + IntNdArray slice3 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(-3)); + + assertEquals(Shape.of(5, 4), slice3.shape()); + assertEquals(2, slice3.getInt(0, 0)); + + IntNdArray slice4 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(-3, true)); + + assertEquals(Shape.of(5, 4, 1), slice4.shape()); + assertEquals(2, slice4.getInt(0, 0, 0)); + } + +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java b/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java index 1c1d89680e7..26ac533daa8 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java @@ -24,11 +24,11 @@ import static org.tensorflow.ndarray.index.Indices.at; import static org.tensorflow.ndarray.index.Indices.even; import static org.tensorflow.ndarray.index.Indices.flip; -import static org.tensorflow.ndarray.index.Indices.from; +import static org.tensorflow.ndarray.index.Indices.sliceFrom; import static org.tensorflow.ndarray.index.Indices.odd; import static org.tensorflow.ndarray.index.Indices.range; import static org.tensorflow.ndarray.index.Indices.seq; -import static org.tensorflow.ndarray.index.Indices.to; +import static org.tensorflow.ndarray.index.Indices.sliceTo; import java.nio.BufferOverflowException; import java.nio.BufferUnderflowException; @@ -212,13 +212,13 @@ public void slices() { assertEquals(val101, vector10_flip.getObject(3)); // Vector (1,0,[from 1]) from vector (1,0,*) - NdArray vector10_1toX = vector10X.slice(from(1)); + NdArray vector10_1toX = vector10X.slice(sliceFrom(1)); assertEquals(vector10_1toX.shape(), Shape.of(4)); assertEquals(val101, vector10_1toX.getObject(0)); assertEquals(val102, vector10_1toX.getObject(1)); // 
Vector (1,0,[to 1]) from vector (1,0,*) - NdArray vector10_Xto1 = vector10X.slice(to(2)); + NdArray vector10_Xto1 = vector10X.slice(sliceTo(2)); assertEquals(vector10_Xto1.shape(), Shape.of(2)); assertEquals(val100, vector10_Xto1.getObject(0)); assertEquals(val101, vector10_Xto1.getObject(1)); diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/benchmark/NdArrayBenchmark.java b/ndarray/src/test/java/org/tensorflow/ndarray/benchmark/NdArrayBenchmark.java index 8acfdff7721..fb7022bc830 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/benchmark/NdArrayBenchmark.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/benchmark/NdArrayBenchmark.java @@ -38,7 +38,7 @@ import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.StdArrays; -@Fork(value = 0, jvmArgs = {"-Xms4G", "-Xmx4G"}) +@Fork(value = 1, jvmArgs = {"-Xms4G", "-Xmx4G"}) @BenchmarkMode(Mode.AverageTime) @Warmup(iterations = 3) @Measurement(iterations = 5) diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/DenseNdArrayTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/DenseNdArrayTest.java index d5b5ca809a4..375f7643875 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/DenseNdArrayTest.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/DenseNdArrayTest.java @@ -40,7 +40,7 @@ public void equalsAndHashCodeOnSlices() { {{3, 4}, {6, 7}} }); - assertTrue(vector1.equals(vector2.slice(Indices.from(2)))); + assertTrue(vector1.equals(vector2.slice(Indices.sliceFrom(2)))); assertTrue(vector1.equals(matrix1.get(1))); assertTrue(vector1.equals(matrix2.get(1).slice(Indices.even()))); assertTrue(matrix1.equals(matrix2.slice(Indices.all(), Indices.even()))); diff --git a/tensorflow-core/tensorflow-core-api/build.sh b/tensorflow-core/tensorflow-core-api/build.sh index 356a00db91d..e94efa850d8 100755 --- a/tensorflow-core/tensorflow-core-api/build.sh +++ b/tensorflow-core/tensorflow-core-api/build.sh @@ -24,7 +24,7 @@ fi if [[ 
"${EXTENSION:-}" == *gpu* ]]; then export BUILD_FLAGS="$BUILD_FLAGS --config=cuda" - export TF_CUDA_COMPUTE_CAPABILITIES="3.5,7.0" + export TF_CUDA_COMPUTE_CAPABILITIES="${TF_CUDA_COMPUTE_CAPABILITIES:-"3.5,7.0"}" if [[ -z ${TF_CUDA_PATHS:-} ]] && [[ -d ${CUDA_PATH:-} ]]; then # Work around some issue with Bazel preventing it from detecting CUDA on Windows export TF_CUDA_PATHS="$CUDA_PATH" diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 84736ada6a5..529b0d99c39 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -38,6 +38,7 @@ import org.tensorflow.ndarray.buffer.FloatDataBuffer; import org.tensorflow.ndarray.buffer.IntDataBuffer; import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.index.Index; import org.tensorflow.op.core.Abort; import org.tensorflow.op.core.All; import org.tensorflow.op.core.Any; @@ -210,6 +211,7 @@ import org.tensorflow.op.core.StridedSlice; import org.tensorflow.op.core.StridedSliceAssign; import org.tensorflow.op.core.StridedSliceGrad; +import org.tensorflow.op.core.StridedSliceHelper; import org.tensorflow.op.core.Sum; import org.tensorflow.op.core.SwitchCond; import org.tensorflow.op.core.TemporaryVariable; @@ -345,10 +347,10 @@ public final class Ops { public final SignalOps signal; - public final TrainOps train; - public final QuantizationOps quantization; + public final TrainOps train; + private final Scope scope; private Ops(Scope scope) { @@ -370,8 +372,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - train = new TrainOps(this); quantization = new QuantizationOps(this); + train = new TrainOps(this); } /** @@ -1831,6 +1833,19 @@ public Constant constant(Shape shape, 
IntDataBuffer data) { return Constant.tensorOf(scope, shape, data); } + /** + * Creates a scalar of {@code type}, with the value of {@code number}. + * {@code number} may be truncated if it does not fit in the target type. + * + * @param type the type of tensor to create. Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating}) + * @param number the value of the tensor + * @return a constant of the passed type + * @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or unknown. + */ + public Constant constant(Class type, Number number) { + return Constant.tensorOf(scope, type, number); + } + /** * Create a {@link TString} constant with data from the given buffer, using the given encoding. * @@ -1876,6 +1891,20 @@ public Constant constantOf(T tensor) { return Constant.create(scope, tensor); } + /** + * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. + * {@code number} may be truncated if it does not fit in the target type. + * + * @param toMatch the operand providing the target type + * @param number the value of the tensor + * @return a constant with the same type as {@code toMatch} + * @see Ops#constant(Class, Number) + * @throws IllegalArgumentException if the type is unknown (which should be impossible). + */ + public Constant constantOfSameType(Operand toMatch, Number number) { + return Constant.tensorOfSameType(scope, toMatch, number); + } + /** * This op consumes a lock created by `MutexLock`. *

@@ -5894,6 +5923,56 @@ public StopGradient stopGradient(Operand input) { return StopGradient.create(scope, input); } + /** + * Return a strided slice from `input`. + *

+ * The goal of this op is to produce a new tensor with a subset of the elements from the `n` dimensional `input` + * tensor. The subset is chosen using a sequence of `m` sparse range specifications encoded into the arguments of this + * function. Note, in some cases `m` could be equal to `n`, but this need not be the case. Each range specification + * entry can be one of the following: + *

+ * - An ellipsis (...) using {@link Indices#ellipsis()}. Ellipses are used to imply zero or more dimensions of + * full-dimension selection. For example, {@code stridedSlice(foo, Indices.ellipsis())} is the identity slice. + *

+ * - A new axis using {@link Indices#newAxis()}. This is used to insert a new shape=1 dimension. + * For example, {@code stridedSlice(foo, Indices.newAxis())} where {@code foo} is shape {@code (3, 4)} + * produces a {@code (1, 3, 4)} tensor. + *

+ * - A range {@code begin:end:stride} using {@link Indices#slice(Long, Long, long)} or {@link Indices#all()}. This is used to specify + * how much to choose from a given dimension. {@code stride} can be any integer but 0. {@code begin} is an integer which + * represents the index of the first value to select while {@code end} represents the index of the last value to select + * (exclusive). Begin and end can be null, in which case the index begins or ends at the beginning or end of the dimension, + * respectively (reversed if stride is negative). When both are null, {@code slice()} is the same as {@code all()}. + * The number of values selected in each dimension is {@code end - begin} if {@code stride > 0} and {@code begin - end} + * if {@code stride < 0}. {@code begin} and {@code end} can be negative where {@code -1} is the last element, {@code -2} + * is the second to last. For example, given a shape {@code (3,)} tensor {@code stridedSlice(foo, Indices.all())}, the + * effective {@code begin} and {@code end} are {@code 0} and {@code 3}. Do not assume this is equivalent to + * {@code stridedSlice(foo, Indices.slice(0, -1))} which has an effective {@code begin} and {@code end} of {@code 0} and + * {@code 2}. Another example is {@code stridedSlice(foo, Indices.slice(-2, null, -1))} which reverses the first dimension + * of a tensor while dropping the last two (in the original order). For example {@code foo = [1,2,3,4]; + * stridedSlice(foo, Indices.slice(-2, null, -1))} is {@code [4,3]}. + *

+ * - A single index using {@link Indices#at(long)}. This is used to keep only elements that have a given index. For + * example {@code stridedSlice(foo, Indices.at(2))} on a shape {@code (5,6)} tensor produces a shape {@code (6,)} tensor. + * The dimension can be kept with size one using {@link Indices#at(long, boolean)}. + *

+ * These semantics generally follow NumPy's indexing semantics, which can be found here: + * https://numpy.org/doc/stable/reference/arrays.indexing.html + *

+ * + * Requirements: + * `0 != strides[i] for i in [0, m)` Only one ellipsis. + * + * @param scope current scope + * @param data type for {@code output()} output + * @param indices The indices to slice. See {@link Indices}. + * @return a new instance of StridedSlice + * @see Indices + */ + public StridedSlice stridedSlice(Operand input, Index... indices) { + return StridedSliceHelper.stridedSlice(scope, input, indices); + } + /** * Return a strided slice from `input`. *

@@ -6006,6 +6085,28 @@ public StridedSlice stridedSlice(Operand return StridedSlice.create(scope, input, begin, end, strides, options); } + /** + * Assign `value` to the sliced l-value reference of `ref`. + *

+ * The values of `value` are assigned to the positions in the variable `ref` that are selected by the slice + * parameters. The slice parameters `begin`, `end`, `strides`, etc. work exactly as in `StridedSlice`. + *

+ * NOTE this op currently does not support broadcasting and so `value`'s shape must be exactly the shape produced by + * the slice of `ref`. + * + * @param data type for {@code outputRef()} output + * @param scope current scope + * @param ref the tensor to assign to. + * @param value the value to assign. + * @param indices The indices to slice. See {@link Indices}. + * @return a new instance of StridedSliceAssign + * @see org.tensorflow.op.Ops#stridedSlice(Operand, Index...) + */ + public StridedSliceAssign stridedSliceAssign(Operand ref, + Operand value, Index... indices) { + return StridedSliceHelper.stridedSliceAssign(scope, ref, value, indices); + } + /** * Assign `value` to the sliced l-value reference of `ref`. *

diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java index 09e5a47f8fd..a5c2df84026 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java @@ -68,6 +68,11 @@ public String type() { return type; } + @Override + public EagerSession env() { + return session; + } + @Override public int numOutputs() { return outputHandles.length; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java index 37f3af7ca26..a865300bc5a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -75,6 +75,7 @@ public EagerOperation build() { @Override public EagerOperationBuilder addInput(Output input) { + session.checkInput(input); addInput(opHandle, (TFE_TensorHandle) input.getUnsafeNativeHandle()); return this; } @@ -83,6 +84,7 @@ public EagerOperationBuilder addInput(Output input) { public EagerOperationBuilder addInputList(Output[] inputs) { TFE_TensorHandle[] inputHandles = new TFE_TensorHandle[inputs.length]; for (int i = 0; i < inputs.length; ++i) { + session.checkInput(inputs[i]); inputHandles[i] = (TFE_TensorHandle) inputs[i].getUnsafeNativeHandle(); } addInputList(opHandle, inputHandles); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index 96ef5228a4f..75bc12b5a6c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ 
b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -27,6 +27,7 @@ import org.tensorflow.internal.c_api.TFE_Context; import org.tensorflow.internal.c_api.TFE_ContextOptions; import org.tensorflow.internal.c_api.TF_Status; +import org.tensorflow.op.Op; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; @@ -297,6 +298,13 @@ public boolean isOpEnabled(String opType) { } } + @Override + public void checkInput(Op input) { + if (!input.env().isEager()) { + throw new IllegalArgumentException("Can't use graph operation " + input + " in eager mode."); + } + } + TFE_Context nativeHandle() { checkSession(); return nativeHandle; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java index a7a3363f690..d5389bcd0ad 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java @@ -15,7 +15,11 @@ package org.tensorflow; -/** Defines an environment for creating and executing TensorFlow {@link Operation}s. */ +import org.tensorflow.op.Op; + +/** + * Defines an environment for creating and executing TensorFlow {@link Operation}s. + */ public interface ExecutionEnvironment { enum Types { @@ -36,13 +40,23 @@ enum Types { /** * Returns true if the given operation is valid in this execution environment. + * * @param opType The op to check. * @return Whether the given operation is valid in this execution environment. */ - default boolean isOpEnabled(String opType){ + default boolean isOpEnabled(String opType) { return true; } + /** + * Checks that {@code input} is valid to use as an input in this execution environment. Throws {@link + * IllegalArgumentException} if not. 
+ * + * @param input The op to check + * @throws IllegalArgumentException if input can't be used as an input in this execution environment. + */ + void checkInput(Op input); + /** * Get the type of this environment (from the `Environments` enumeration. * diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index f2717f263eb..988683895c4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -158,6 +158,17 @@ public Types environmentType() { return Types.GRAPH; } + @Override + public void checkInput(Op input) { + if (input.env().isEager()) { + throw new IllegalArgumentException( + "Input " + input + " was from an eager session, can't use in a graph. Use tf.constantOf(input.asTensor())"); + } + if (input.env() != this) { + throw new IllegalArgumentException("Input " + input + " was from a different graph, can't use."); + } + } + /** * Import a representation of a TensorFlow graph. 
* diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java index e1255748c3b..fbad92160a2 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java @@ -73,6 +73,13 @@ public String type() { } } + @Override + public Graph env() { + try (Graph.Reference r = graph.ref()) { + return graph; + } + } + @Override public int numOutputs() { Graph.Reference r = graph.ref(); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java index 9c0f011bab4..72858ece572 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java @@ -92,6 +92,11 @@ public GraphOperationBuilder addControlInput(Operation control) { throw new IllegalArgumentException( "Only GraphOperation instances can be used as control inputs"); } + + if (control.env() != graph) { + throw new IllegalArgumentException("Control input " + control + " was from a different graph, can't use."); + } + Graph.Reference r = graph.ref(); try { addControlInput(unsafeNativeHandle, ((GraphOperation) control).getUnsafeNativeHandle()); @@ -103,6 +108,7 @@ public GraphOperationBuilder addControlInput(Operation control) { @Override public GraphOperationBuilder addInput(Output input) { + graph.checkInput(input); Graph.Reference r = graph.ref(); try { addInput(unsafeNativeHandle, (TF_Operation) input.getUnsafeNativeHandle(), input.index()); @@ -114,6 +120,10 @@ public GraphOperationBuilder addInput(Output input) { @Override public GraphOperationBuilder addInputList(Output[] inputs) { + for (Output input : 
inputs) { + graph.checkInput(input); + } + Graph.Reference r = graph.ref(); try { TF_Operation[] opHandles = new TF_Operation[inputs.length]; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operation.java index 1cc175da161..b47eee6850c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operation.java @@ -25,16 +25,24 @@ */ public interface Operation { - /** Returns the full name of the Operation. */ + /** + * Returns the full name of the Operation. + */ String name(); /** - * Returns the type of the operation, i.e., the name of the computation performed by the - * operation. + * Returns the type of the operation, i.e., the name of the computation performed by the operation. */ String type(); - /** Returns the number of tensors produced by this operation. */ + /** + * Returns the execution environment this operation was created in. + */ + ExecutionEnvironment env(); + + /** + * Returns the number of tensors produced by this operation. 
+ */ int numOutputs(); /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java index ea32d1fff13..66b4dad4132 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -15,9 +15,11 @@ */ package org.tensorflow; +import java.util.HashMap; import java.util.Map; import java.util.Set; import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.proto.framework.SignatureDef; import org.tensorflow.proto.framework.TensorInfo; import org.tensorflow.proto.framework.TensorShapeProto; @@ -32,6 +34,16 @@ public class Signature { /** The default signature key, when not provided */ public static final String DEFAULT_KEY = "serving_default"; + public static class TensorDescription { + public final DataType dataType; + public final Shape shape; + + public TensorDescription(DataType dataType, Shape shape) { + this.dataType = dataType; + this.shape = shape; + } + } + /** * Builds a new function signature. 
*/ @@ -174,6 +186,32 @@ public String toString() { return strBuilder.toString(); } + private Map buildTensorDescriptionMap(Map dataMapIn) { + Map dataTypeMap = new HashMap<>(); + dataMapIn.forEach((a, b) -> { + long[] tensorDims = b.getTensorShape().getDimList().stream().mapToLong(d -> d.getSize()).toArray(); + Shape tensorShape = Shape.of(tensorDims); + dataTypeMap.put(a, new TensorDescription(b.getDtype(), + tensorShape)); + }); + return dataTypeMap; + } + + /** + * Returns the names of the inputs in this signature mapped to their expected data type and shape + * @return + */ + public Map getInputs() { + return buildTensorDescriptionMap(signatureDef.getInputsMap()); + } + + /** + * Returns the names of the outputs in this signature mapped to their expected data type and shape + */ + public Map getOutputs() { + return buildTensorDescriptionMap(signatureDef.getOutputsMap()); + } + Signature(String key, SignatureDef signatureDef) { this.key = key; this.signatureDef = signatureDef; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Op.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Op.java index 40b54393c60..6051623414f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Op.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Op.java @@ -15,6 +15,7 @@ package org.tensorflow.op; +import org.tensorflow.ExecutionEnvironment; import org.tensorflow.Operation; /** @@ -48,4 +49,11 @@ public interface Op { * @return an {@link Operation} */ Operation op(); + + /** + * Return the execution environment this op was created in. 
+ */ + default ExecutionEnvironment env() { + return op().env(); + } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java index f0b739e074f..73fa340a487 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java @@ -16,14 +16,12 @@ package org.tensorflow.op; import java.util.ArrayList; - import org.tensorflow.DeviceSpec; import org.tensorflow.ExecutionEnvironment; import org.tensorflow.OperationBuilder; /** - * Manages groups of related properties when creating Tensorflow Operations, such as a common name - * prefix. + * Manages groups of related properties when creating Tensorflow Operations, such as a common name prefix. * *

A {@code Scope} is a container for common properties applied to TensorFlow Ops. Normal user * code initializes a {@code Scope} and provides it to Operation building classes. For example: @@ -88,7 +86,9 @@ public Scope(ExecutionEnvironment env) { this(env, new NameScope(), new ArrayList<>(), DeviceSpec.newBuilder().build()); } - /** Returns the execution environment used by this scope. */ + /** + * Returns the execution environment used by this scope. + */ public ExecutionEnvironment env() { return env; } @@ -97,8 +97,7 @@ public ExecutionEnvironment env() { * Returns a new scope where added operations will have the provided name prefix. * *

Ops created with this scope will have {@code name/childScopeName/} as the prefix. The actual - * name will be unique in the returned scope. All other properties are inherited from the current - * scope. + * name will be unique in the returned scope. All other properties are inherited from the current scope. * *

The child scope name must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*} * @@ -129,7 +128,8 @@ public Scope withName(String opName) { /** * Return a new scope that uses the provided device specification for an op. * - *

Operations created within this scope will place the created operations on the device(s) matching the provided spec. + *

Operations created within this scope will place the created operations on the device(s) matching the provided + * spec. * * @param deviceSpec device specification for an operator in the returned scope * @return a new Scope that uses opName for operations. @@ -151,8 +151,8 @@ public Scope withDevice(DeviceSpec deviceSpec) { * } * *

Note: if you provide a composite operator building class (i.e, a class that creates a - * set of related operations by calling other operator building code), the provided name will act - * as a subscope to all underlying operators. + * set of related operations by calling other operator building code), the provided name will act as a subscope to all + * underlying operators. * * @param defaultName name for the underlying operator. * @return unique name for the operator. @@ -180,11 +180,15 @@ private Scope( * @return a new scope with the provided control dependencies */ public Scope withControlDependencies(Iterable controls) { + for (Op control : controls) { + env.checkInput(control); + } return new Scope(env, nameScope, controls, deviceSpec); } /** - * Applies device specification and adds each Operand in controlDependencies as a control input to the provided builder. + * Applies device specification and adds each Operand in controlDependencies as a control input to the provided + * builder. * * @param builder OperationBuilder to add control inputs and device specification to */ @@ -210,7 +214,9 @@ public OperationBuilder applyControlDependencies(OperationBuilder builder) { private final NameScope nameScope; private final DeviceSpec deviceSpec; - /** Returns device string from the scope. */ + /** + * Returns device string from the scope. 
+ */ public String getDeviceString() { return deviceSpec.toString(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java index 918f9083923..497ee5f2d46 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java @@ -20,18 +20,6 @@ import org.tensorflow.Operation; import org.tensorflow.Output; import org.tensorflow.Tensor; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.BooleanDataBuffer; -import org.tensorflow.ndarray.buffer.ByteDataBuffer; -import org.tensorflow.ndarray.buffer.DataBuffer; -import org.tensorflow.ndarray.buffer.DoubleDataBuffer; -import org.tensorflow.ndarray.buffer.FloatDataBuffer; -import org.tensorflow.ndarray.buffer.IntDataBuffer; -import org.tensorflow.ndarray.buffer.LongDataBuffer; import org.tensorflow.ndarray.BooleanNdArray; import org.tensorflow.ndarray.ByteNdArray; import org.tensorflow.ndarray.DoubleNdArray; @@ -40,14 +28,30 @@ import org.tensorflow.ndarray.LongNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.op.Ops; +import org.tensorflow.op.RawOp; +import 
org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TBfloat16; import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat16; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.TUint8; +import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; /** @@ -1277,6 +1281,67 @@ public static Constant tensorOf(Scope scope, Shape shape) { return vectorOf(scope, shape.asArray()); } + /** + * Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be truncated if it does not + * fit in the target type. + * + * @param type the type of tensor to create. Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating}) + * @param number the value of the tensor + * @return a constant of the passed type + * @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or + * unknown. 
+ */ + @SuppressWarnings("unchecked") + @Endpoint + public static Constant tensorOf(Scope scope, Class type, Number number) { + if (type.equals(TBfloat16.class)) { + try (TBfloat16 tensor = TBfloat16.scalarOf(number.floatValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TFloat64.class)) { + try (TFloat64 tensor = TFloat64.scalarOf(number.doubleValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TFloat32.class)) { + try (TFloat32 tensor = TFloat32.scalarOf(number.floatValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TFloat16.class)) { + try (TFloat16 tensor = TFloat16.scalarOf(number.floatValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TInt64.class)) { + try (TInt64 tensor = TInt64.scalarOf(number.longValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TInt32.class)) { + try (TInt32 tensor = TInt32.scalarOf(number.intValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TUint8.class)) { + try (TUint8 tensor = TUint8.scalarOf(number.byteValue())) { + return (Constant) create(scope, tensor); + } + } else { + throw new IllegalArgumentException("Tensor type " + type + " is an abstract or unknown numeric type."); + } + } + + /** + * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code number} may be + * truncated if it does not fit in the target type. + * + * @param toMatch the operand providing the target type + * @param number the value of the tensor + * @return a constant with the same type as {@code toMatch} + * @throws IllegalArgumentException if the type is unknown (which should be impossible). 
+ * @see Ops#constant(Class, Number) + */ + @Endpoint(name = "constantOfSameType") + public static Constant tensorOfSameType(Scope scope, Operand toMatch, Number number) { + return tensorOf(scope, toMatch.type(), number); + } + /** * Create a constant by making an immutable copy of {@code tensor}. {@code tensor} may be closed afterwards without * issue. diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java new file mode 100644 index 00000000000..e97934ee312 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java @@ -0,0 +1,221 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.op.core; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.index.Indices; +import org.tensorflow.ndarray.index.Index; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * Helper endpoint methods for Python like indexing. 
+ * + * @see org.tensorflow.ndarray.index.Indices + */ +@Operator +public abstract class StridedSliceHelper { + + static class StridedSliceArgs { + + final int[] begin; + final int[] end; + final int[] strides; + final long beginMask; + final long endMask; + final long ellipsisMask; + final long newAxisMask; + final long shrinkAxisMask; + + private StridedSliceArgs(int[] begin, int[] end, int[] strides, long beginMask, long endMask, long ellipsisMask, + long newAxisMask, long shrinkAxisMask) { + this.begin = begin; + this.end = end; + this.strides = strides; + this.beginMask = beginMask; + this.endMask = endMask; + this.ellipsisMask = ellipsisMask; + this.newAxisMask = newAxisMask; + this.shrinkAxisMask = shrinkAxisMask; + } + } + + static StridedSliceArgs mergeIndexes(Index[] indices) { + int[] begin = new int[indices.length]; + int[] end = new int[indices.length]; + int[] strides = new int[indices.length]; + long beginMask = 0; + long endMask = 0; + long ellipsisMask = 0; + long newAxisMask = 0; + long shrinkAxisMask = 0; + + for (int i = 0; i < indices.length; i++) { + Index idx = indices[i]; + + if (idx == null) { + idx = Indices.all(); + } + + if (!idx.isStridedSlicingCompliant()) { + throw new UnsupportedOperationException("Index " + idx + " is not supported for Tensors"); + } + + begin[i] = (int) idx.begin(); + if (begin[i] != idx.begin()) { + throw new IllegalArgumentException( + "Can't convert long begin value to int for index " + idx + ": Out of bounds"); + } + + end[i] = (int) idx.end(); + if (end[i] != idx.end()) { + throw new IllegalArgumentException("Can't convert long end value to int for index " + idx + ": Out of bounds"); + } + + strides[i] = (int) idx.stride(); + if (strides[i] != idx.stride()) { + throw new IllegalArgumentException( + "Can't convert long stride value to int for index " + idx + ": Out of bounds"); + } + + if (idx.beginMask()) { + beginMask |= 1L << i; + } + + if (idx.endMask()) { + endMask |= 1L << i; + } + + if (idx.isEllipsis()) 
{ + if (ellipsisMask != 0) { + throw new IllegalArgumentException("Can not have two ellipsis in a slice"); + } + ellipsisMask |= 1L << i; + } + + if (idx.isNewAxis()) { + newAxisMask |= 1L << i; + } + + if (idx.isPoint()) { + shrinkAxisMask |= 1L << i; + } + } + + return new StridedSliceArgs(begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); + } + + /** + * Return a strided slice from `input`. + *

+ * The goal of this op is to produce a new tensor with a subset of the elements from the `n` dimensional `input` + * tensor. The subset is chosen using a sequence of `m` sparse range specifications encoded into the arguments of this + * function. Note, in some cases `m` could be equal to `n`, but this need not be the case. Each range specification + * entry can be one of the following: + *

+ * - An ellipsis (...) using {@link Indices#ellipsis()}. Ellipses are used to imply zero or more dimensions of + * full-dimension selection. For example, {@code stridedSlice(foo, Indices.ellipsis())} is the identity slice. + *

+ * - A new axis using {@link Indices#newAxis()}. This is used to insert a new shape=1 dimension. + * For example, {@code stridedSlice(foo, Indices.newAxis())} where {@code foo} is shape {@code (3, 4)} + * produces a {@code (1, 3, 4)} tensor. + *

+ * - A range {@code begin:end:stride} using {@link Indices#slice(Long, Long, long)} or {@link Indices#all()}. This is used to specify + * how much to choose from a given dimension. {@code stride} can be any integer but 0. {@code begin} is an integer which + * represents the index of the first value to select while {@code end} represents the index of the last value to select + * (exclusive). Begin and end can be null, in which case the index begins or ends at the beginning or end of the dimension, + * respectively (reversed if stride is negative). When both are null, {@code slice()} is the same as {@code all()}. + * The number of values selected in each dimension is {@code end - begin} if {@code stride > 0} and {@code begin - end} + * if {@code stride < 0}. {@code begin} and {@code end} can be negative where {@code -1} is the last element, {@code -2} + * is the second to last. For example, given a shape {@code (3,)} tensor {@code stridedSlice(foo, Indices.all())}, the + * effective {@code begin} and {@code end} are {@code 0} and {@code 3}. Do not assume this is equivalent to + * {@code stridedSlice(foo, Indices.slice(0, -1))} which has an effective {@code begin} and {@code end} of {@code 0} and + * {@code 2}. Another example is {@code stridedSlice(foo, Indices.slice(-2, null, -1))} which reverses the first dimension + * of a tensor while dropping the last two (in the original order). For example {@code foo = [1,2,3,4]; + * stridedSlice(foo, Indices.slice(-2, null, -1))} is {@code [4,3]}. + *

+ * - A single index using {@link Indices#at(long)}. This is used to keep only elements that have a given index. For + * example {@code stridedSlice(foo, Indices.at(2))} on a shape {@code (5,6)} tensor produces a shape {@code (6,)} tensor. + * The dimension can be kept with size one using {@link Indices#at(long, boolean)}. + *

+ * These semantics generally follow NumPy's indexing semantics, which can be found here: + * https://numpy.org/doc/stable/reference/arrays.indexing.html + *

+ * + * Requirements: + * `0 != strides[i] for i in [0, m)` Only one ellipsis. + * + * @param scope current scope + * @param data type for {@code output()} output + * @param indices The indices to slice. See {@link Indices}. + * @return a new instance of StridedSlice + * @see Indices + */ + @Endpoint(name = "stridedSlice") + public static StridedSlice stridedSlice(Scope scope, Operand input, Index... indices) { + StridedSliceArgs args = mergeIndexes(indices); + return StridedSlice.create( + scope, + input, + Constant.vectorOf(scope, args.begin), + Constant.vectorOf(scope, args.end), + Constant.vectorOf(scope, args.strides), + StridedSlice.beginMask(args.beginMask), + StridedSlice.endMask(args.endMask), + StridedSlice.ellipsisMask(args.ellipsisMask), + StridedSlice.newAxisMask(args.newAxisMask), + StridedSlice.shrinkAxisMask(args.shrinkAxisMask) + ); + } + + /** + * Assign `value` to the sliced l-value reference of `ref`. + *

+ * The values of `value` are assigned to the positions in the variable `ref` that are selected by the slice + * parameters. The slice parameters `begin`, `end`, `strides`, etc. work exactly as in `StridedSlice`. + *

+ * NOTE this op currently does not support broadcasting and so `value`'s shape must be exactly the shape produced by + * the slice of `ref`. + * + * @param data type for {@code outputRef()} output + * @param scope current scope + * @param ref the tensor to assign to. + * @param value the value to assign. + * @param indices The indices to slice. See {@link Indices}. + * @return a new instance of StridedSliceAssign + * @see org.tensorflow.op.Ops#stridedSlice(Operand, Index...) + */ + @Endpoint(name = "stridedSliceAssign") + public static StridedSliceAssign stridedSliceAssign(Scope scope, Operand ref, + Operand value, Index... indices) { + StridedSliceArgs args = mergeIndexes(indices); + return StridedSliceAssign.create( + scope, + ref, + Constant.vectorOf(scope, args.begin), + Constant.vectorOf(scope, args.end), + Constant.vectorOf(scope, args.strides), + value, + StridedSliceAssign.beginMask(args.beginMask), + StridedSliceAssign.endMask(args.endMask), + StridedSliceAssign.ellipsisMask(args.ellipsisMask), + StridedSliceAssign.newAxisMask(args.newAxisMask), + StridedSliceAssign.shrinkAxisMask(args.shrinkAxisMask) + ); + } + +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java index bbb9e23ec90..33ae979ccbd 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java @@ -23,7 +23,6 @@ import org.tensorflow.exceptions.TFInvalidArgumentException; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Constant; import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.TBool; import org.tensorflow.types.TInt32; @@ -31,22 +30,6 @@ /** Unit tests for {@link org.tensorflow.GraphOperationBuilder}. 
*/ public class GraphOperationBuilderTest { - @Test - public void failWhenMixingOperationsOnDifferentGraphs() { - try (Graph g1 = new Graph(); - Graph g2 = new Graph()) { - Ops tf = Ops.create(g1); - Constant c1 = tf.constant(3); - tf.math.add(c1, c1); - try { - Ops tf2 = Ops.create(g2); - tf2.math.add(c1, c1); - } catch (Exception e) { - fail(e.toString()); - } - } - } - @Test public void failOnUseAfterBuild() { try (Graph g = new Graph(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java index e1436358a68..c9740ce4a6b 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java @@ -14,11 +14,14 @@ ==============================================================================*/ package org.tensorflow; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; - import org.junit.jupiter.api.Test; +import org.tensorflow.Signature.TensorDescription; import org.tensorflow.op.Ops; +import org.tensorflow.proto.framework.DataType; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; public class SignatureTest { @@ -43,6 +46,29 @@ public void cannotDuplicateInputOutputNames() { } } + @Test + public void getInputsAndOutputs() { + Ops tf = Ops.create(); + Signature builder = Signature.builder() + .input("x", tf.constant(10.0f)) + .output("y", tf.constant(new float[][] {{10.0f, 30.0f}})) + .output("z", tf.constant(20.0f)).build(); + + Map inputs = builder.getInputs(); + assertEquals(inputs.size(), 1); + + Map outputs = builder.getOutputs(); + assertEquals(outputs.size(), 2); + + assertEquals(outputs.get("y").dataType, DataType.DT_FLOAT); + assertEquals(outputs.get("z").dataType, DataType.DT_FLOAT); + 
assertArrayEquals(outputs.get("y").shape.asArray(), new long [] {1,2}); + assertArrayEquals(outputs.get("z").shape.asArray(), new long [] {}); + + Signature emptySignature = Signature.builder().build(); + assertEquals(emptySignature.getInputs().size(), 0); + } + @Test public void emptyMethodNameConvertedToNull() { Signature signature = Signature.builder().key("f").build(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java new file mode 100644 index 00000000000..b2fbc1e794a --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java @@ -0,0 +1,108 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================================== + */ +package org.tensorflow; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import org.junit.Test; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TInt32; + +/** + * Tests for using Operands in different environments + */ +public class WrongEnvTest { + + /** + * Should work fine + */ + @Test + public void testTwoEagers() { + try (EagerSession e1 = EagerSession.create(); + EagerSession e2 = EagerSession.create()) { + Ops tf1 = Ops.create(e1); + Ops tf2 = Ops.create(e2); + + Operand a = tf1.constant(5); + Operand b = tf2.constant(6); + + Operand c = tf2.math.add(a, b); + + try (TInt32 tensor = c.asTensor()) { + assertEquals(11, tensor.getInt()); + } + + } + } + + @Test + public void testEagerInGraph() { + try (EagerSession e1 = EagerSession.create(); + Graph e2 = new Graph()) { + Ops tf1 = Ops.create(e1); + Ops tf2 = Ops.create(e2); + + Operand a = tf1.constant(5); + Operand b = tf2.constant(6); + + Operand c = tf2.math.add(a, b); + + fail(); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("was from an eager session, can't use in a graph")); + } + } + + @Test + public void testGraphInEager() { + try (Graph e1 = new Graph(); + EagerSession e2 = EagerSession.create()) { + Ops tf1 = Ops.create(e1); + Ops tf2 = Ops.create(e2); + + Operand a = tf1.constant(5); + Operand b = tf2.constant(6); + + Operand c = tf2.math.add(a, b); + + fail(); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Can't use graph operation")); + } + } + + @Test + public void testTwoGraphs() { + try (Graph e1 = new Graph(); + Graph e2 = new Graph()) { + Ops tf1 = Ops.create(e1); + Ops tf2 = Ops.create(e2); + + Operand a = tf1.constant(5); + Operand b = tf2.constant(6); + + Operand c = tf2.math.add(a, b); + + fail(); + 
} catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("was from a different graph")); + } + } + +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java index 5dd6903d913..6df73261867 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -18,15 +18,19 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import java.io.IOException; - import org.junit.jupiter.api.Test; import org.tensorflow.AutoCloseableList; import org.tensorflow.EagerSession; import org.tensorflow.Graph; +import org.tensorflow.Operand; import org.tensorflow.Session; import org.tensorflow.Tensor; -import org.tensorflow.op.Ops; -import org.tensorflow.op.Scope; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.IntNdArray; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.buffer.DataBuffer; import org.tensorflow.ndarray.buffer.DataBuffers; @@ -34,19 +38,20 @@ import org.tensorflow.ndarray.buffer.FloatDataBuffer; import org.tensorflow.ndarray.buffer.IntDataBuffer; import org.tensorflow.ndarray.buffer.LongDataBuffer; -import org.tensorflow.ndarray.DoubleNdArray; -import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.IntNdArray; -import org.tensorflow.ndarray.LongNdArray; -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.op.Ops; +import org.tensorflow.op.Scope; +import org.tensorflow.types.TBfloat16; +import org.tensorflow.types.TFloat16; import org.tensorflow.types.TFloat32; import 
org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; +import org.tensorflow.types.TUint8; +import org.tensorflow.types.family.TNumber; public class ConstantTest { + private static final float EPSILON = 1e-7f; @Test @@ -56,7 +61,7 @@ public void createInts() { IntNdArray array = NdArrays.wrap(shape, buffer); try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g)) { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); @@ -164,4 +169,28 @@ public void createFromTensorsInEagerMode() throws IOException { assertEquals(NdArrays.vectorOf(1, 2, 3, 4), c1.asTensor()); } } + + private static void testCreateFromNumber(Ops tf, Class type) { + Operand constant = tf.constant(type, 10); + assertEquals(type, constant.type()); + + try (TFloat64 t = tf.dtypes.cast(constant, TFloat64.class).asTensor()) { + assertEquals(10.0, t.getDouble()); + } + } + + @Test + public void createFromNumber() { + try (EagerSession s = EagerSession.create()) { + Ops tf = Ops.create(s); + + testCreateFromNumber(tf, TBfloat16.class); + testCreateFromNumber(tf, TFloat64.class); + testCreateFromNumber(tf, TFloat32.class); + testCreateFromNumber(tf, TFloat16.class); + testCreateFromNumber(tf, TInt64.class); + testCreateFromNumber(tf, TInt32.class); + testCreateFromNumber(tf, TUint8.class); + } + } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java new file mode 100644 index 00000000000..6e86573b7cf --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java @@ -0,0 +1,72 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.op.core; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.Test; +import org.tensorflow.Graph; +import org.tensorflow.Session; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.index.Indices; +import org.tensorflow.ndarray.index.Index; +import org.tensorflow.op.Scope; +import org.tensorflow.types.TFloat32; + +public class IndexingTest { + + // [2, 1:2, :, tf.newaxis, ..., :4, 4::2] + private static final Index[] slice = new Index[]{ + Indices.at(2), + Indices.at(1, true), + Indices.all(), + Indices.newAxis(), + Indices.ellipsis(), + Indices.sliceTo( 4), + Indices.sliceFrom(4, 2) + }; + + @Test + public void testIndexMerge() { + StridedSliceHelper.StridedSliceArgs args = StridedSliceHelper.mergeIndexes(slice); + + assertArrayEquals(new int[]{2, 1, 0, 0, 0, 0, 4}, args.begin); + assertArrayEquals(new int[]{3, 2, 0, 0, 0, 4, 0}, args.end); + assertArrayEquals(new int[]{1, 1, 1, 1, 1, 1, 2}, args.strides); + assertEquals(0b0100100, args.beginMask); + assertEquals(0b1000100, args.endMask); + assertEquals(0b0010000, args.ellipsisMask); + assertEquals(0b0001000, args.newAxisMask); + assertEquals(0b0000001, args.shrinkAxisMask); + + } + + @Test + public void testStridedSliceIndex(){ + try (Graph g = new Graph(); + Session sess 
= new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {10, 10, 10, 10, 10, 10, 10, 10}; + Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat32.class); + StridedSlice output = StridedSliceHelper.stridedSlice(scope, op, slice); + try (TFloat32 result = (TFloat32) sess.runner().fetch(output.asOutput()).run().get(0)) { + // expected shape from Python tensorflow + assertEquals(Shape.of(1, 10, 1, 10, 10, 10, 4, 3), result.shape(), "Slice index didn't match expected (Python)"); + } + } + } + +} diff --git a/tensorflow-framework/pom.xml b/tensorflow-framework/pom.xml index 74c4c2c2084..71cb99bbb95 100644 --- a/tensorflow-framework/pom.xml +++ b/tensorflow-framework/pom.xml @@ -94,7 +94,6 @@ 1 false -Xmx2G -XX:MaxPermSize=256m - false **/*Test.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java index 290e4e80b57..894bd073758 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java @@ -62,7 +62,6 @@ * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. *

For a GlorotUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} * for the distribution parameter. - *

* * @param The TType for the call operation * @see VarianceScaling.Distribution diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java index 9b1a0887af0..3a91b72b0d0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java @@ -57,7 +57,6 @@ * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. *

For an HeUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} * for the distribution parameter. - *

* * @param The TType for the call operation * @see The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 363291fa5cc..5aac163c1e4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -154,24 +154,26 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. 
labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 */ public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, - * and a channel axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy Loss using a Loss Reduction of {@link + * Loss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS} * * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 */ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) { this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); @@ -183,9 +185,10 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float la * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. x=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. 
x=0.2 means + * that we will use a value of 0.1 for label 0 and 0.9 + * for label 1 * @param reduction Type of Reduction to apply to loss. */ public CategoricalCrossentropy( @@ -199,13 +202,14 @@ public CategoricalCrossentropy( * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 * @param reduction Type of Reduction to apply to loss. * @param axis The channels axis. axis=-1 corresponds to data format "Channels Last" - * and axis=1 corresponds to data format "Channels First". - * {@link Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST} + * and axis=1 corresponds to data format "Channels First". {@link + * Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST} * @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1. */ public CategoricalCrossentropy( @@ -242,13 +246,12 @@ public CategoricalCrossentropy( * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. 
*/ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java index f592c19f8bb..73837ed1756 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -25,7 +25,7 @@ *

loss = maximum(neg - pos + 1, 0) where neg=maximum((1-labels)*predictions) * and pos=sum(labels*predictions) * - *

labels values are expected to be 0 or 1.

+ *

labels values are expected to be 0 or 1. * *

Standalone usage: * @@ -99,8 +99,8 @@ public CategoricalHinge(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.categoricalHinge(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java index 137c7025c04..0a18d93caf3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java @@ -22,12 +22,13 @@ /** * Computes the cosine similarity between labels and predictions. * - *

Note that it is a number between -1 and 1. When it is a negative number between -1 and 0, 0 - * indicates orthogonality and values closer to -1indicate greater similarity. The values closer to - * 1 indicate greater dissimilarity. This makes it usable as a loss function in a setting where you - * try to maximize the proximity between predictions and targets. If either labels or predictions is - * a zero vector, cosine similarity will be 0 regardless of the proximity between predictions and - * targets. + *

Note that it is a number between -1 and 1. When it is a negative + * number between -1 and 0, 0 indicates orthogonality and + * values closer to -1indicate greater similarity. The values closer to 1 + * indicate greater dissimilarity. This makes it usable as a loss function in a setting where you + * try to maximize the proximity between predictions and targets. If either labels or + * predictions is a zero vector, cosine similarity will be 0 regardless of + * the proximity between predictions and targets. * *

loss = -sum(l2Norm(labels) * l2Norm(predictions)) * @@ -71,7 +72,7 @@ public class CosineSimilarity extends Loss { public static final int DEFAULT_AXIS = -1; public static final Reduction DEFAULT_REDUCTION = Reduction.AUTO; - private final int axis; + private final int[] axis; /** * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, an axis @@ -107,6 +108,17 @@ public CosineSimilarity(Ops tf, int axis) { this(tf, null, axis, DEFAULT_REDUCTION); } + /** + * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, and a + * Loss Reduction of {@link #DEFAULT_REDUCTION} + * + * @param tf the TensorFlow Ops + * @param axis The dimension along which the cosine similarity is computed. + */ + public CosineSimilarity(Ops tf, int[] axis) { + + this(tf, null, axis, DEFAULT_REDUCTION); + } /** * Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION} @@ -120,6 +132,18 @@ public CosineSimilarity(Ops tf, String name, int axis) { this(tf, name, axis, DEFAULT_REDUCTION); } + /** + * Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION} + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param axis The dimension along which the cosine similarity is computed. + */ + public CosineSimilarity(Ops tf, String name, int[] axis) { + + this(tf, name, axis, DEFAULT_REDUCTION); + } + /** * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name and an * axis of {@link #DEFAULT_AXIS} @@ -153,6 +177,18 @@ public CosineSimilarity(Ops tf, String name, Reduction reduction) { */ public CosineSimilarity(Ops tf, int axis, Reduction reduction) { + this(tf, null, new int[] {axis}, reduction); + } + + /** + * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name + * + * @param tf the TensorFlow Ops + * @param axis The dimension along which the cosine similarity is computed. 
+ * @param reduction Type of Reduction to apply to the loss. + */ + public CosineSimilarity(Ops tf, int[] axis, Reduction reduction) { + this(tf, null, axis, reduction); } @@ -165,15 +201,28 @@ public CosineSimilarity(Ops tf, int axis, Reduction reduction) { * @param reduction Type of Reduction to apply to the loss. */ public CosineSimilarity(Ops tf, String name, int axis, Reduction reduction) { + this(tf, name, new int[] {axis}, reduction); + } + + /** + * Creates a Cosine Similarity Loss + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param axis The dimension along which the cosine similarity is computed. + * @param reduction Type of Reduction to apply to the loss. + */ + public CosineSimilarity(Ops tf, String name, int[] axis, Reduction reduction) { super(tf, name, reduction); this.axis = axis; } /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.cosineSimilarity(getTF(), labels, predictions, axis); + losses = tf.math.neg(losses); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index 88b4a7aa056..d4c350ef06c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -18,15 +18,16 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; + import static org.tensorflow.framework.utils.CastHelper.cast; /** * Computes the hinge loss between labels and predictions. * - *

loss = maximum(1 - labels * predictions, 0)

. + *

loss = maximum(1 - labels * predictions, 0). * - *

labels values are expected to be -1 or 1. - * If binary (0 or 1) labels are provided, they will be converted to -1 or 1.

+ *

labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided, + * they will be converted to -1 or 1. * *

Standalone usage: * @@ -106,7 +107,7 @@ public Hinge(Ops tf, String name, Reduction reduction) { * label values are not in the set [-1., 0., 1.]. * * @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be - * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. + * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is * provided, then the loss is simply scaled by the given value. If sampleWeights is a tensor @@ -116,21 +117,19 @@ public Hinge(Ops tf, String name, Reduction reduction) { * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - @SuppressWarnings("unchecked") - Operand tLabels = predictions.type() == labels.type() ? 
- (Operand)labels : cast(tf, labels, predictions.type()); - tLabels = LossesHelper.valueCheck( + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + Operand tLabels = cast(tf, labels, predictions.type()); + tLabels = + LossesHelper.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.type())); + cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type())); Operand losses = Losses.hinge(getTF(), tLabels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index 6d3e3f0c2ac..b1aee1b0656 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -89,6 +89,7 @@ public Huber(Ops tf) { * Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ public Huber(Ops tf, String name) { this(tf, name, DELTA_DEFAULT, Reduction.AUTO); @@ -109,6 +110,7 @@ public Huber(Ops tf, Reduction reduction) { * Creates a Huber Loss using {@link #DELTA_DEFAULT} as the delta * * @param tf the TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. */ public Huber(Ops tf, String name, Reduction reduction) { @@ -119,7 +121,7 @@ public Huber(Ops tf, String name, Reduction reduction) { * Creates a Huber Loss * * @param tf the TensorFlow Ops - * @param name the name of the loss + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. 
* @param delta the point where the Huber loss function changes from quadratic to linear. * @param reduction Type of Reduction to apply to the loss. */ @@ -130,8 +132,8 @@ public Huber(Ops tf, String name, float delta, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.huber(getTF(), labels, predictions, delta); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java index 8cf3db8d518..2aa1f72092b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java @@ -99,8 +99,8 @@ public KLDivergence(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index 1669669a768..a11d582e527 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -77,6 +77,7 @@ public LogCosh(Ops tf) { * Creates a LogCosh Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the 
TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ public LogCosh(Ops tf, String name) { this(tf, name, Reduction.AUTO); @@ -96,7 +97,7 @@ public LogCosh(Ops tf, Reduction reduction) { * Creates a LogCosh Loss * * @param tf the TensorFlow Ops - * @param name the name of the loss + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. */ public LogCosh(Ops tf, String name, Reduction reduction) { @@ -105,8 +106,8 @@ public LogCosh(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.logCosh(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index ae33d5dfa37..cdd35d28aba 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -25,7 +25,7 @@ public abstract class Loss { protected final Reduction reduction; /** - * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link + * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link * Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops @@ -62,10 +62,10 @@ protected Loss(Ops tf, String name, Reduction reduction) { * @param labels the truth values or labels * @param predictions the predictions * @param The data type of the predictions and loss. - * @param The data type of the labels. 
* @return the loss */ - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { return call(labels, predictions, null); } @@ -82,11 +82,10 @@ public Operand call(Operand labels, * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss */ - public abstract Operand call( - Operand labels, Operand predictions, Operand sampleWeights); + public abstract Operand call( + Operand labels, Operand predictions, Operand sampleWeights); /** * Gets the TensorFlow Ops diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 0d25bd5e7e2..9aa94cf7fcf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -48,11 +48,10 @@ public class Losses { * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the mean absolute error */ - public static Operand meanAbsoluteError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanAbsoluteError( + Ops tf, Operand labels, Operand predictions) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -70,11 +69,10 @@ public static Operand meanAbsoluteErro * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the mean squared error */ - public static Operand 
meanSquaredError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanSquaredError( + Ops tf, Operand labels, Operand predictions) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -91,11 +89,10 @@ public static Operand meanSquaredError * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the mean absolute percentage error */ - public static Operand meanAbsolutePercentageError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanAbsolutePercentageError( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -105,8 +102,10 @@ public static Operand meanAbsolutePerc tf.math.abs( tf.math.div( tf.math.sub(tLabels, predictions), - tf.math.maximum(tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), predictionType)))); - return tf.math.mul(cast(tf, tf.constant(100), predictionType), tf.math.mean(diff, tf.constant(-1))); + tf.math.maximum( + tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), predictionType)))); + return tf.math.mul( + cast(tf, tf.constant(100), predictionType), tf.math.mean(diff, tf.constant(-1))); } /** @@ -118,11 +117,10 @@ public static Operand meanAbsolutePerc * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the mean squared logarithmic percentage error */ - public static Operand meanSquaredLogarithmicError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanSquaredLogarithmicError( + Ops tf, Operand labels, Operand predictions) { Class 
predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -152,8 +150,12 @@ public static Operand meanSquaredLogar * @param the data type of the predictions and labels * @return the binary crossentropy loss. */ - public static Operand binaryCrossentropy( - Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { + public static Operand binaryCrossentropy( + Ops tf, + Operand labels, + Operand predictions, + boolean fromLogits, + float labelSmoothing) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -181,7 +183,7 @@ private static Operand binaryCrossentropyHelper( Ops tf, Operand target, Operand output, boolean fromLogits) { if (fromLogits) return tf.nn.sigmoidCrossEntropyWithLogits(target, output); - /* TODO - skip this loggic for now. It requires walking back the inputs which is not yet possible + /* TODO - skip this logic for now. It requires walking back the inputs which is not yet possible if (!(output instanceof Variable) && (!tf.scope().env().isEager())) { // TODO - this does not work // TODO output = backtrackIdentity(output); @@ -218,16 +220,17 @@ private static Operand binaryCrossentropyHelper( * @param labels true targets * @param predictions the predictions * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. 
labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 * @param axis the * @param the data type of the predictions and labels * @return the categorical crossentropy loss. */ - public static Operand categoricalCrossentropy( + public static Operand categoricalCrossentropy( Ops tf, - Operand labels, + Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing, @@ -283,8 +286,8 @@ public static Operand categoricalCross * @param the data type of the predictions and labels * @return the categorical hinge loss */ - public static Operand categoricalHinge( - Ops tf, Operand labels, Operand predictions) { + public static Operand categoricalHinge( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -329,8 +332,8 @@ public static Operand categoricalHinge * @param the data type of the predictions and labels * @return the cosine similarity loss */ - public static Operand cosineSimilarity( - Ops tf, Operand labels, Operand predictions, int axis) { + public static Operand cosineSimilarity( + Ops tf, Operand labels, Operand predictions, int[] axis) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); @@ -339,8 +342,7 @@ public static Operand cosineSimilarity tLabels = l2Normalize(tf, tLabels, axis); predictions = l2Normalize(tf, predictions, axis); Operand mathMul = tf.math.mul(tLabels, predictions); - Operand sum = tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE)); - return tf.math.neg(sum); + return tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE)); } /** @@ -355,8 +357,8 @@ public static Operand cosineSimilarity * @param the data type of the 
predictions and labels * @return the hinge loss */ - public static Operand hinge( - Ops tf, Operand labels, Operand predictions) { + public static Operand hinge( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -391,8 +393,8 @@ public static Operand hinge( * @param the data type of the predictions and labels * @return the Huber loss */ - public static Operand huber( - Ops tf, Operand labels, Operand predictions, float delta) { + public static Operand huber( + Ops tf, Operand labels, Operand predictions, float delta) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -422,8 +424,8 @@ public static Operand huber( * @see Kullback?Leibler * divergence */ - public static Operand kullbackLeiblerDivergence( - Ops tf, Operand labels, Operand predictions) { + public static Operand kullbackLeiblerDivergence( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -452,8 +454,8 @@ public static Operand kullbackLeiblerD * @param the data type of the predictions and labels * @return the hyperbolic cosine divergence loss */ - public static Operand logCosh( - Ops tf, Operand labels, Operand predictions) { + public static Operand logCosh( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -480,8 +482,8 @@ public static Operand logCosh( * @param the data type of the predictions and 
labels * @return the Poisson loss */ - public static Operand poisson( - Ops tf, Operand labels, Operand predictions) { + public static Operand poisson( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -507,8 +509,12 @@ public static Operand poisson( * @param the data type of the predictions and labels * @return the sparse categorical crossentropy loss */ - public static Operand sparseCategoricalCrossentropy( - Ops tf, Operand labels, Operand predictions, boolean fromLogits, int axis) { + public static Operand sparseCategoricalCrossentropy( + Ops tf, + Operand labels, + Operand predictions, + boolean fromLogits, + int axis) { Class predictionType = predictions.type(); Operand epsilonConst = cast(tf, tf.constant(EPSILON), predictionType); Operand one = cast(tf, tf.constant(1), predictionType); @@ -553,7 +559,7 @@ public static Operand sparseCategorica int labelsRank = labelsShape.numDimensions(); boolean updateShape = labelsRank != predictionsRank - 1; - if (updateShape) { // TODO check to see if this is right + if (updateShape) { Shape newShape = labelsShape.take(labelsRank - 1); iLabels = tf.reshape(iLabels, tf.constant(newShape)); // flatten one dimension predictions = @@ -584,8 +590,8 @@ public static Operand sparseCategorica * @param the data type of the predictions and labels * @return the squared hinge loss */ - public static Operand squaredHinge( - Ops tf, Operand labels, Operand predictions) { + public static Operand squaredHinge( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -649,14 +655,14 @@ private static Operand smoothCategoricalLabels( * @param tf The TensorFlow Ops * 
@param x the input * @param axis Dimension along which to normalize. + * @param the data type for the input and the result * @return the normalized values based on L2 norm */ - public static Operand l2Normalize(Ops tf, Operand x, int axis) { + public static Operand l2Normalize(Ops tf, Operand x, int[] axis) { Operand squareSum = tf.reduceSum(tf.math.square(x), tf.constant(axis), ReduceSum.keepDims(Boolean.TRUE)); Operand invNorm = - tf.math.rsqrt( - tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.type()))); + tf.math.rsqrt(tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.type()))); return tf.math.mul(x, invNorm); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index a2d5d5f8efc..03a3cf70110 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -95,8 +95,8 @@ public MeanAbsoluteError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsoluteError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java index 49133df610b..6c5242df4f2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -95,8 
+95,8 @@ public MeanAbsolutePercentageError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsolutePercentageError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java index 2a6c2be885e..f975db55c44 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -95,8 +95,8 @@ public MeanSquaredError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java index 2604e226b81..11b8e157e90 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -95,8 +95,8 @@ public MeanSquaredLogarithmicError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, 
Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java index c43be4f2821..78324acf8a5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java @@ -76,6 +76,7 @@ public Poisson(Ops tf) { * Creates a Poisson Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ public Poisson(Ops tf, String name) { this(tf, name, Reduction.AUTO); @@ -95,7 +96,7 @@ public Poisson(Ops tf, Reduction reduction) { * Creates a Poisson Loss * * @param tf the TensorFlow Ops - * @param name the name of the loss + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. 
*/ public Poisson(Ops tf, String name, Reduction reduction) { @@ -104,8 +105,8 @@ public Poisson(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.poisson(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index ea765e6f8fd..d04cc67d5d9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -18,6 +18,7 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; + import static org.tensorflow.framework.utils.CastHelper.cast; /** @@ -79,7 +80,8 @@ public class SparseCategoricalCrossentropy extends Loss { /** * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss - * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and fromLogits={@link + * #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops */ @@ -88,8 +90,8 @@ public SparseCategoricalCrossentropy(Ops tf) { } /** - * Creates a SparseCategoricalCrossentropy loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, - * and fromLogits={@link #FROM_LOGITS_DEFAULT}. 
+ * Creates a SparseCategoricalCrossentropy loss using a Loss Reduction of {@link + * Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param name the name of this loss function @@ -122,8 +124,8 @@ public SparseCategoricalCrossentropy(Ops tf, String name, Reduction reduction) { } /** - * Creates a SparseCategoricalCrossentropy using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and - * fromLogits={@link #FROM_LOGITS_DEFAULT}. + * Creates a SparseCategoricalCrossentropy using a Loss Reduction of {@link + * Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param name the name of this loss function @@ -135,7 +137,8 @@ public SparseCategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { /** * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss - * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} and fromLogits={@link + * #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values @@ -176,9 +179,10 @@ public SparseCategoricalCrossentropy( /** * Generates an Operand the calculates the loss. * - * If run in Graph mode, the computation will throw {@link org.tensorflow.exceptions.TFInvalidArgumentException} - * if the predictions values are outside the range o [0. to 1.]. In Eager Mode, this call - * will throw {@link IllegalArgumentException}, if the predictions values are outside the range o [0. to 1.] + *

If run in Graph mode, the computation will throw {@link + * org.tensorflow.exceptions.TFInvalidArgumentException} if the predictions values are outside the + * range of [0. to 1.]. In Eager Mode, this call will throw {@link IllegalArgumentException}, if + * the predictions values are outside the range of [0. to 1.] * * @param labels the truth values or labels * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. @@ -190,23 +194,22 @@ public SparseCategoricalCrossentropy( * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 lPredictions = - LossesHelper.rangeCheck( - getTF(), - "predictions range check [0-1]", - predictions, - cast(getTF(), getTF().constant(0), predictions.type()), - cast(getTF(), getTF().constant(1), predictions.type())); + LossesHelper.rangeCheck( + getTF(), + "predictions range check [0-1]", + predictions, + cast(getTF(), getTF().constant(0), predictions.type()), + cast(getTF(), getTF().constant(1), predictions.type())); } else { lPredictions = predictions; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index 4ad4c1c726c..dadbdb3b95e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -18,6 +18,7 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; + import static org.tensorflow.framework.utils.CastHelper.cast; /** @@ -25,8 +26,8 @@ * *

loss = square(maximum(1 - labels * predictions, 0)) * - *

labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided, they will be - * converted to -1 or 1. + *

labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided, + * they will be converted to -1 or 1. * *

Standalone usage: * @@ -107,7 +108,7 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { * label values are not in the set [-1., 0., 1.]. * * @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be - * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. + * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. * @param sampleWeights Optional SampleWeights acts as a coefficient for the loss. If a scalar is * provided, then the loss is simply scaled by the given value. If SampleWeights is a tensor @@ -117,21 +118,23 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") - Operand tLabels = predictions.type() == labels.type() ? - (Operand)labels : cast(tf, labels, predictions.type()); - tLabels = LossesHelper.valueCheck( + Operand tLabels = + predictions.type() == labels.type() + ? 
(Operand) labels + : cast(tf, labels, predictions.type()); + tLabels = + LossesHelper.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.type())); + cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type())); Operand losses = Losses.squaredHinge(getTF(), tLabels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java index 2104937a979..f811549fbca 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java @@ -18,7 +18,7 @@ import org.tensorflow.types.family.TNumber; /** - * A helper class for loss methods to return labels, target, and sampleWeights + * A helper class for loss methods to return labels, target, and sampleWeights * * @param the data type of the LossTuple entries. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index 10067db91ba..f6b0de71b0d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -32,8 +32,9 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * These are helper methods for Losses and Metrics and will be module private when Java modularity is applied to - * TensorFlow Java. These methods should not be used outside of the losses and metrics packages. 
+ * These are helper methods for Losses and Metrics and will be module private when Java modularity + * is applied to TensorFlow Java. These methods should not be used outside of the losses and metrics + * packages. */ public class LossesHelper { @@ -42,16 +43,17 @@ public class LossesHelper { * *

    *
  1. Squeezes last dim of predictions or labels if their rank - * differs by 1 (using {@link #removeSqueezableDimensions}).
  2. + * differs by 1 (using {@link #removeSqueezableDimensions}). *
  3. Squeezes or expands last dim of sampleWeight if its rank differs by 1 from * the new rank of predictions. If sampleWeight is scalar, it is - * kept scalar.
  4. + * kept scalar. *
* * @param tf the TensorFlow Ops * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction * . + * @param the data type for the labels, predictions and result * @return LossTuple of prediction, label,sampleWeight will * be null. Each of them possibly has the last dimension squeezed, sampleWeight * could be extended by one dimension. If sampleWeight is null, (prediction, @@ -77,12 +79,14 @@ public static LossTuple squeezeOrExpandDimensions( * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction * . - * @param sampleWeights Optional sample weight(s) Operand whose dimensions match + * @param sampleWeights Optional sample weight(s) Operand whose dimensions match + * * prediction. - * @return LossTuple of predictions, labels and sampleWeight. - * Each of them possibly has the last dimension squeezed, sampleWeight could be - * extended by one dimension. If sampleWeight is null, only the possibly shape modified predictions and labels are - * returned. + * @param the data type for the labels, predictions and result + * @return LossTuple of predictions, labels and sampleWeight + * . Each of them possibly has the last dimension squeezed, sampleWeight + * could be extended by one dimension. If sampleWeight is null, only the possibly + * shape modified predictions and labels are returned. */ public static LossTuple squeezeOrExpandDimensions( Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { @@ -178,6 +182,7 @@ private static Operand maybeExpandWeights( * @param labels Label values, a Tensor whose dimensions match predictions * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. + * @param the data type for the labels, predictions and result * @return labels and predictions, possibly with last dim squeezed. 
*/ public static LossTuple removeSqueezableDimensions( @@ -193,6 +198,7 @@ public static LossTuple removeSqueezableDimensions( * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. * @param expectedRankDiff Expected result of rank(predictions) - rank(labels). + * @param the data type for the labels, predictions and result * @return labels and predictions, possibly with last dim squeezed. */ public static LossTuple removeSqueezableDimensions( @@ -216,7 +222,8 @@ public static LossTuple removeSqueezableDimensions( } // Use dynamic rank. - // TODO Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); + // TODO: hold for lazy select feature, + // Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) { /* * TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze @@ -298,8 +305,7 @@ private static Operand reduceWeightedLoss( public static Operand safeMean( Ops tf, Operand losses, long numElements) { Operand totalLoss = tf.reduceSum(losses, allAxes(tf, losses)); - return tf.math.divNoNan( - totalLoss, cast(tf, tf.constant(numElements), losses.type())); + return tf.math.divNoNan(totalLoss, cast(tf, tf.constant(numElements), losses.type())); } /** @@ -383,8 +389,7 @@ public static Operand rangeCheck( */ public static Operand valueCheck( Ops tf, String prefix, Operand values, Operand allowedValues) { - Operand flatValues = - tf.reshape(values, tf.constant(Shape.of(values.shape().size()))); + Operand flatValues = tf.reshape(values, tf.constant(Shape.of(values.shape().size()))); SetDiff1d diff = tf.setDiff1d(flatValues, allowedValues, TInt32.class); long diffSize = diff.out().shape().size(); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index 
651a6fac0b0..48ee244eafb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -21,17 +21,18 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. * *

This is the crossentropy metric class to be used when there are only two label classes (0 and * 1). * - * @param the data type for the predictions. * @param The data type for the metric result */ -public class BinaryCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class BinaryCrossentropy extends MeanMetricWrapper + implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; @@ -41,7 +42,8 @@ public class BinaryCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. * @param labelSmoothing value used to smooth labels, When 0, no smoothing occurs. When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, * where the smoothing squeezes the labels towards 0.5. 
Larger values of label_smoothing @@ -60,7 +62,10 @@ public BinaryCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.binaryCrossentropy(getTF(), labels, predictions, fromLogits, labelSmoothing); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.binaryCrossentropy(getTF(), tLabels, tPredictions, fromLogits, labelSmoothing); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index c330ea88eaa..b22e5415f79 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A Metric that computes the categorical cross-entropy loss between true labels and predicted * labels. @@ -30,11 +32,10 @@ * [2, 0, 1], the labels Operand contains = [[0, 0, 1], [1, 0, 0], [0, 1, 0]] * . * - * @param the data type for the predictions. * @param The data type for the metric result */ -public class CategoricalCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class CategoricalCrossentropy extends MeanMetricWrapper + implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; @@ -48,7 +49,8 @@ public class CategoricalCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. 
- * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to + * a probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and 0.9 @@ -68,7 +70,8 @@ public CategoricalCrossentropy( * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, * meaning the confidence on label values are relaxed. e.g. 
labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and 0.9 @@ -98,8 +101,11 @@ public CategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.categoricalCrossentropy( - getTF(), labels, predictions, fromLogits, labelSmoothing, axis); + getTF(), tLabels, tPredictions, fromLogits, labelSmoothing, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index 2741a36edb6..4266cc487c0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -21,13 +21,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A Metric that computes the categorical hinge loss metric between labels and predictions. * - * @param the data type for the predictions. 
* @param The data type for the metric result */ -public class CategoricalHinge extends MeanMetricWrapper +public class CategoricalHinge extends MeanMetricWrapper implements LossMetric { /** @@ -46,7 +47,10 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.categoricalHinge(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.categoricalHinge(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 458de092bec..840f255c5ab 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -15,18 +15,20 @@ package org.tensorflow.framework.metrics; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the cosine similarity metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class CosineSimilarity extends MeanMetricWrapper +public class CosineSimilarity extends MeanMetricWrapper implements LossMetric { public static final int DEFAULT_AXIS = -1; private final int[] axis; @@ -76,8 +78,12 @@ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class typ /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - // NOTE: cosineProximity is a different algorithm than Losses.cosineSimilarity - return Metrics.cosineProximity(getTF(), labels, predictions, axis); + public Operand call( + Operand labels, Operand predictions) { + // NOTE: metrics.CosineSimilarity is Losses.cosineSimilarity, + // while losses.CosineSimilarity is the negative of Losses.cosineSimilarity + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.cosineSimilarity(getTF(), tLabels, tPredictions, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index baf9ad8ab7d..46ccd2859ff 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -21,14 +21,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the hinge loss metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class Hinge extends MeanMetricWrapper - implements LossMetric { +public class Hinge extends MeanMetricWrapper implements LossMetric { /** * Creates a Hinge metric @@ -46,7 +46,10 @@ public Hinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.hinge(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.hinge(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index efcbbcbb7f0..9ffcd6189f1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -21,15 +21,15 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the Kullback-Leibler divergence loss metric between labels and * predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class KLDivergence extends MeanMetricWrapper - implements LossMetric { +public class KLDivergence extends MeanMetricWrapper implements LossMetric { /** * Creates a KLDivergence metric @@ -47,7 +47,10 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.kullbackLeiblerDivergence(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index 3df8505d54b..59e24f57110 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -21,15 +21,15 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the logarithm of the hyperbolic cosine of the prediction error metric * between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class LogCoshError extends MeanMetricWrapper - implements LossMetric { +public class LogCoshError extends MeanMetricWrapper implements LossMetric { /** * Creates a LogCoshError metric @@ -47,7 +47,10 @@ public LogCoshError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.logCosh(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.logCosh(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java index de1f5a5629e..8902b329bcc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java @@ -21,10 +21,9 @@ /** * A metric that that implements a weighted mean {@link MetricReduction#WEIGHTED_MEAN } * - * @param The data type for the metric values * @param The data type for the metric result */ -public class Mean extends Reduce { +public class Mean extends Reduce { /** * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#SUM} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index e27676932ff..1cc6d0b6f99 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -21,13 +21,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static 
org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. */ -public class MeanAbsoluteError extends MeanMetricWrapper +public class MeanAbsoluteError extends MeanMetricWrapper implements LossMetric { /** @@ -46,7 +47,10 @@ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanAbsoluteError(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanAbsoluteError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index 84fa9b627b2..8c6720b58f6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -21,14 +21,15 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class MeanAbsolutePercentageError - extends MeanMetricWrapper implements LossMetric { +public class MeanAbsolutePercentageError extends MeanMetricWrapper + implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -46,7 +47,10 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanAbsolutePercentageError(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanAbsolutePercentageError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index c7edd6ebe93..3c4c79d39ba 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -21,13 +21,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class MeanSquaredError extends MeanMetricWrapper +public class MeanSquaredError extends MeanMetricWrapper implements LossMetric { /** @@ -46,7 +47,10 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanSquaredError(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanSquaredError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index 199b6e0e114..d525bb76648 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -21,14 +21,15 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class MeanSquaredLogarithmicError - extends MeanMetricWrapper implements LossMetric { +public class MeanSquaredLogarithmicError extends MeanMetricWrapper + implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -46,7 +47,10 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanSquaredLogarithmicError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index bbb2aa73da2..468919e696d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -25,10 +25,9 @@ /** * Base class for Metrics * - * @param The data type for the metric values * @param The data type for the metric result */ -public abstract class Metric { +public abstract class Metric { /** The TensorFlow Ops */ private final Ops tf; @@ -75,10 +74,10 @@ protected Metric(Ops tf, String name, long seed) { * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. 
* @return a List of Operations to update the metric state - * @param the data type for sampleWeights */ @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList( + Operand values, Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -90,13 +89,13 @@ public List updateStateList(Operand values, Operand the data type for the labels - * @param the data type for the sampleWeights * @return a List of Operations to update the metric state */ @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -105,10 +104,10 @@ public List updateStateList( * * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. - * @param the data type for sampleWeights * @return the Operation to update the metric state */ - public final Op updateState(Operand values, Operand sampleWeights) { + public final Op updateState( + Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } @@ -119,12 +118,12 @@ public final Op updateState(Operand values, Operand sa * @param labels the labels * @param predictions the predictions * @param sampleWeights sample weights to be applied to values, may be null. 
- * @param the data type for the labels - * @param the data type for the sampleWeights * @return the Operation to update the metric state */ - public final Op updateState( - Operand labels, Operand predictions, Operand sampleWeights) { + public final Op updateState( + Operand labels, + Operand predictions, + Operand sampleWeights) { List controlOps = updateStateList(labels, predictions, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } @@ -149,10 +148,9 @@ public final Op updateState( * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. * @return the result, possibly with control dependencies - * @param the data type for the sampleWeights. */ - public final Operand callOnce( - Operand values, Operand sampleWeights) { + public final Operand callOnce( + Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps); return ltf.identity(result()); @@ -186,7 +184,11 @@ public String getName() { return name; } - /** The random number generator seed value */ + /** + * Gets the random number generator seed value + * + * @return the random number generator seed value + */ public long getSeed() { return seed; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index 0169bc6b8bc..95b74bf1eea 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -17,7 +17,6 @@ import org.tensorflow.Operand; import org.tensorflow.framework.utils.CastHelper; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.ReduceSum; import org.tensorflow.types.TFloat32; import 
org.tensorflow.types.family.TNumber; @@ -46,89 +45,14 @@ public class Metrics { * @param predictions The prediction values. * @param k Number of top elements to look at for computing accuracy. * @param the data type for the predictions and results - * @param the data type ofr the labels. * @return the Operand for the Top K categorical accuracy value. */ - public static Operand topKCategoricalAccuracy( - Ops tf, Operand labels, Operand predictions, long k) { + public static Operand topKCategoricalAccuracy( + Ops tf, Operand labels, Operand predictions, long k) { Operand fPredictions = CastHelper.cast(tf, predictions, TFloat32.class); return CastHelper.cast( tf, tf.nn.inTopK(fPredictions, tf.math.argMax(labels, tf.constant(-1)), tf.constant(k)), predictions.type()); } - - /** - * Computes the cosine similarity between labels and predictions. - * - * @param tf the TensorFlow Ops - * @param labels The ground truth values. - * @param predictions The prediction values. - * @param axes The dimensions along which the cosine similarity is computed. - * @param the data type for the labels - * @param the data type for the predictions and result - * @return Cosine similarity value. - */ - public static Operand cosineProximity( - Ops tf, Operand labels, Operand predictions, int[] axes) { - Operand labelsNorm = CastHelper.cast(tf, labels, predictions.type()); - labelsNorm = l2Normalize(tf, labelsNorm, axes); - - Operand predictionsNorm = l2Normalize(tf, predictions, axes); - Operand mathMul = tf.math.mul(labelsNorm, predictionsNorm); - return tf.reduceSum(mathMul, tf.constant(axes), ReduceSum.keepDims(Boolean.FALSE)); - } - - /** - * Normalizes along dimension axis using an L2 norm with an epsilon of {@link - * #L2_NORM_EPSILON}. - * - *

For a 1-D tensor with axis = 0, computes - * - *

-   *       output = x / sqrt(max(sum(x**2), epsilon))
-   * 
- * - *

For x with more dimensions, independently normalizes each 1-D slice along - * dimension axis. - * - * @param tf The TensorFlow ops - * @param x The operand to normalize - * @param axes Dimension(s) along which to normalize. - * @param The data type for x. - * @return the normalized values of x. - */ - public static Operand l2Normalize(Ops tf, Operand x, int[] axes) { - return l2Normalize(tf, x, axes, L2_NORM_EPSILON); - } - - /** - * Normalizes along dimension axis using an L2 norm. - * - *

For a 1-D tensor with axis = 0, computes - * - *

-   *       output = x / sqrt(max(sum(x**2), epsilon))
-   * 
- * - *

For x with more dimensions, independently normalizes each 1-D slice along - * dimension axis. - * - * @param tf The TensorFlow ops - * @param x The operand to normalize - * @param axes Dimension(s) along which to normalize. - * @param epsilon A lower bound value for the norm. Will use sqrt(epsilon) as the - * divisor if norm < sqrt(epsilon). - * @param The data type for the values. - * @return the normalized values of x. - */ - public static Operand l2Normalize( - Ops tf, Operand x, int[] axes, float epsilon) { - Operand squareSum = - tf.reduceSum(tf.math.square(x), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE)); - Operand y = - tf.math.rsqrt( - tf.math.maximum(squareSum, CastHelper.cast(tf, tf.constant(epsilon), x.type()))); - return tf.math.mul(x, y); - } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 75a2031fbb5..422fd4808ff 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -21,14 +21,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the poisson loss metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class Poisson extends MeanMetricWrapper - implements LossMetric { +public class Poisson extends MeanMetricWrapper implements LossMetric { /** * Creates a Poisson metric @@ -46,7 +46,10 @@ public Poisson(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.poisson(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.poisson(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index 2e01f722de6..e954169b2af 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -21,15 +21,16 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the sparse categorical cross-entropy loss between true labels and * predicted labels. * - * @param the data type for the predictions. * @param The data type for the metric result. */ -public class SparseCategoricalCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class SparseCategoricalCrossentropy extends MeanMetricWrapper + implements LossMetric { private final boolean fromLogits; private final int axis; @@ -39,7 +40,8 @@ public class SparseCategoricalCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. 
- * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. * @param axis The dimension along which the entropy is computed. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -55,7 +57,10 @@ public SparseCategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.sparseCategoricalCrossentropy(getTF(), tLabels, tPredictions, fromLogits, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index 430dbbcc229..19b3b1d0ac4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -21,14 +21,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the squared hinge loss metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class SquaredHinge extends MeanMetricWrapper - implements LossMetric { +public class SquaredHinge extends MeanMetricWrapper implements LossMetric { /** * Creates a SquaredHinge metric @@ -46,7 +46,10 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.squaredHinge(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.squaredHinge(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index b7b87d313aa..1fb3d3bb580 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -29,8 +29,7 @@ public interface LossMetric { * * @param labels the truth values or labels * @param predictions the predictions - * @param The data type of the labels. 
* @return the loss */ - Operand call(Operand labels, Operand predictions); + Operand call(Operand labels, Operand predictions); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 17c209a8fed..9a532a0294f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -17,13 +17,14 @@ import org.tensorflow.Operand; import org.tensorflow.framework.metrics.Mean; import org.tensorflow.framework.metrics.MetricReduction; -import org.tensorflow.framework.utils.CastHelper; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; import java.util.List; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A class that bridges a stateless loss function with the {@link Mean} metric using a reduction of * {@link MetricReduction#WEIGHTED_MEAN}. @@ -32,10 +33,9 @@ * then passes this loss to the {@link Mean} metric to calculate the weighted mean of the * loss over many iterations or epochs * - * @param the data type for the predictions. * @param The data type for the metric result */ -public class MeanMetricWrapper extends Mean { +public class MeanMetricWrapper extends Mean { /** The loss function interface */ protected LossMetric loss; @@ -85,22 +85,21 @@ protected void setLoss(LossMetric loss) { * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of * predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) - * @param the datatype of the labels - * @param the data type for sampleWeights * @return a List of control operations that updates the Mean state variables. 
*/ - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { if (labels == null || predictions == null) { throw new IllegalArgumentException("missing required inputs for labels and predictions"); } - Operand tLabels = CastHelper.cast(getTF(), labels, getResultType()); - Operand tPredictions = CastHelper.cast(getTF(), predictions, getResultType()); + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); Operand losses = loss.call(tLabels, tPredictions); - return super.updateStateList( - CastHelper.cast(getTF(), losses, predictions.type()), sampleWeights); + return super.updateStateList(cast(getTF(), losses, predictions.type()), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index ad8ff58e417..8a352322f52 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -21,12 +21,10 @@ import org.tensorflow.op.Ops; import org.tensorflow.op.math.Mean; import org.tensorflow.types.TBool; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; import java.util.Arrays; import java.util.Collections; @@ -57,13 +55,13 @@ public class MetricsHelper { * @param values the values to which weights are applied. 
* @return Operation with control dependencies to ensure sampleWeight * can be broadcast to values - * @param the type of Operand + * @param the type of Operand * @throws NotBroadcastableException If static checks determine sampleWeights has an * incorrect shape that prohibit broadcasting to values */ @SuppressWarnings("unchecked") - public static Op assertBroadcastable( - Ops tf, Operand sampleWeights, Operand values) { + public static Op assertBroadcastable( + Ops tf, Operand sampleWeights, Operand values) { // try static check for exact match @@ -129,7 +127,7 @@ public static Op assertBroadcastable( // hack to work around the non-lazy select for isValidShape, otherwise validNonscalar fails on a // scalar weight. If select was lazy, that branch wouldn't get executed when iScalar is true. - Operand reshapedWeights = + Operand reshapedWeights = tf.select(isScalar, tf.math.mul(sampleWeights, tf.onesLike(values)), sampleWeights); weightsShape = tf.shape(reshapedWeights); weightsRank = tf.rank(reshapedWeights); @@ -237,11 +235,10 @@ public static Operand mean(Ops tf, Operand x) { * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. * @param the type of the Operand. - * @param the type of the axes. * @return the mean of the operand, along the specified axes. */ - public static Operand mean( - Ops tf, Operand x, Operand axes) { + public static Operand mean( + Ops tf, Operand x, Operand axes) { return mean(tf, x, axes, false); } @@ -257,31 +254,27 @@ public static Operand mean( * @param the type of the operand * @return the mean of elements of x. */ - public static Operand mean( - Ops tf, Operand x, boolean keepDims) { + public static Operand mean(Ops tf, Operand x, boolean keepDims) { return mean(tf, x, null, keepDims); } - - /** * Calculates the mean of the operand, alongside the specified axes. * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. 
- * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the - * * reduced dimensions are retained with length 1. + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained + * with length 1. * @param the data type of the Operand - * @param the data type of the axes * @return the mean of elements of x. */ - - public static Operand mean( - Ops tf, Operand x, Operand axes, boolean keepDims) { + public static Operand mean( + Ops tf, Operand x, Operand axes, boolean keepDims) { if (axes == null) { - axes = (Operand) allAxes(tf, x); + axes = allAxes(tf, x); } return tf.math.mean(x, axes, Mean.keepDims(keepDims)); } @@ -294,7 +287,7 @@ public static Operand mean( * @param x the Operand used to calculate the mean * @return the mean of the operand containing floating point numbers */ - public static Operand booleanMean(Ops tf, Operand x) { + public static Operand booleanMean(Ops tf, Operand x) { return booleanMean(tf, x, null, false); } @@ -305,44 +298,43 @@ public static Operand booleanMean(Ops tf, Operand x) { * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param the type of the axes. * @return the mean of the operand, along the specified axes, containing floating point numbers */ - public static Operand booleanMean( - Ops tf, Operand x,Operand axes) { + public static Operand booleanMean( + Ops tf, Operand x, Operand axes) { return booleanMean(tf, x, axes, false); } /** * Calculates the mean of the boolean operand, alongside all axes. * + * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. 
If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the - * * reduced dimensions are retained with length 1. - * @param the data type of the axes + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained + * with length 1. * @return the mean of elements of x containing floating point numbers */ - public static Operand booleanMean( - Ops tf, Operand x, boolean keepDims) { + public static Operand booleanMean(Ops tf, Operand x, boolean keepDims) { return booleanMean(tf, x, null, keepDims); } /** * Calculates the mean of the boolean operand, alongside the specified axes. * + * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the - * * reduced dimensions are retained with length 1. - * @param the data type of the axes + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained + * with length 1. 
* @return the mean of elements of x containing floating point numbers */ - public static Operand booleanMean( - Ops tf, Operand x, Operand axes, boolean keepDims) { + public static Operand booleanMean( + Ops tf, Operand x, Operand axes, boolean keepDims) { Operand xf = cast(tf, x, TFloat64.class); return mean(tf, xf, axes, keepDims); } - } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 8e48cb4e573..2a26967b9f2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -19,7 +19,6 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.framework.metrics.Metric; import org.tensorflow.framework.metrics.MetricReduction; -import org.tensorflow.framework.utils.CastHelper; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -29,13 +28,14 @@ import java.util.ArrayList; import java.util.List; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Encapsulates metrics that perform a reduce operation on the metric values. * - * @param The data type for the metric values * @param The data type for the metric result */ -public abstract class Reduce extends Metric { +public abstract class Reduce extends Metric { public static final String TOTAL = "total"; public static final String COUNT = "count"; protected final MetricReduction reduction; @@ -45,8 +45,10 @@ public abstract class Reduce extends Metri private final Class resultType; /** the variable that holds the total of the metric values */ protected Variable total; - /** the variable that holds the count of the metric values. 
- * For {@link MetricReduction#WEIGHTED_MEAN}, this count may be weighted */ + /** + * the variable that holds the count of the metric values. For {@link + * MetricReduction#WEIGHTED_MEAN}, this count may be weighted + */ protected Variable count; /** @@ -95,12 +97,10 @@ private void setupVars() { public Op resetStates() { List controls = new ArrayList<>(); if (total != null) { - controls.add( - getTF().assign(total, CastHelper.cast(getTF(), getTF().constant(0), total.type()))); + controls.add(getTF().assign(total, cast(getTF(), getTF().constant(0), total.type()))); } if (count != null) { - controls.add( - getTF().assign(count, CastHelper.cast(getTF(), getTF().constant(0), count.type()))); + controls.add(getTF().assign(count, cast(getTF(), getTF().constant(0), count.type()))); } return getTF().withControlDependencies(controls).noOp(); } @@ -115,67 +115,67 @@ public Op resetStates() { * @throws IllegalArgumentException if values is null */ @Override - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList( + Operand values, Operand sampleWeights) { if (values == null) { throw new IllegalArgumentException("values is required."); } + Ops tf = getTF(); List updateOperations = new ArrayList<>(); // cast everything to match the variables - Operand lSampleWeights = null; - Operand lValues = values; + Operand tSampleWeights = null; + Operand tValues = cast(tf, values, getResultType()); if (sampleWeights != null) { - lSampleWeights = CastHelper.cast(getTF(), sampleWeights, lValues.type()); - LossTuple tuple = - LossesHelper.squeezeOrExpandDimensions(getTF(), null, lValues, lSampleWeights); - lValues = tuple.getTarget(); - lSampleWeights = tuple.getSampleWeights(); + tSampleWeights = cast(getTF(), sampleWeights, getResultType()); + LossTuple tuple = + LossesHelper.squeezeOrExpandDimensions(getTF(), null, tValues, tSampleWeights); + tValues = tuple.getTarget(); + tSampleWeights = tuple.getSampleWeights(); try { - lSampleWeights = 
MetricsHelper.broadcastWeights(getTF(), lSampleWeights, lValues); + tSampleWeights = MetricsHelper.broadcastWeights(getTF(), tSampleWeights, tValues); } catch (IllegalArgumentException ex) { // if we get here we have static shapes with either // different ranks or different dimension sizes. // first, reduce the values down to the rank of the samples - int valuesRank = lValues.shape().numDimensions(); - int weightsRank = lSampleWeights.shape().numDimensions(); + int valuesRank = tValues.shape().numDimensions(); + int weightsRank = tSampleWeights.shape().numDimensions(); int numAxes = Math.min(0, valuesRank - weightsRank); if (numAxes > 0) { // values rank is greater than weights rank, reduce values to weights rank. int[] axes = new int[numAxes]; for (int i = 0; i < numAxes; i++) axes[i] = i + weightsRank; if (reduction == MetricReduction.SUM) { - lValues = getTF().reduceSum(lValues, getTF().constant(axes)); + tValues = getTF().reduceSum(tValues, getTF().constant(axes)); } else { - lValues = getTF().math.mean(lValues, getTF().constant(axes)); + tValues = getTF().math.mean(tValues, getTF().constant(axes)); } } } - lValues = getTF().math.mul(lValues, lSampleWeights); + tValues = getTF().math.mul(tValues, tSampleWeights); } - Operand weightedValueSum = - getTF().reduceSum(lValues, LossesHelper.allAxes(getTF(), lValues)); + Operand weightedValueSum = + getTF().reduceSum(tValues, LossesHelper.allAxes(getTF(), tValues)); Operand totalUpdate = - getTF().assignAdd(total, CastHelper.cast(getTF(), weightedValueSum, total.type())); + getTF().assignAdd(total, cast(getTF(), weightedValueSum, total.type())); updateOperations.add(totalUpdate); Operand numValues; if (reduction != MetricReduction.SUM) { switch (reduction) { case SUM_OVER_BATCH_SIZE: - numValues = - CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), resultType); + numValues = cast(getTF(), getTF().constant(tValues.shape().size()), resultType); break; case WEIGHTED_MEAN: - if (lSampleWeights == null) { 
- numValues = - CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), resultType); + if (tSampleWeights == null) { + numValues = cast(getTF(), getTF().constant(tValues.shape().size()), resultType); } else { numValues = - CastHelper.cast( + cast( getTF(), getTF() - .reduceSum(lSampleWeights, LossesHelper.allAxes(getTF(), lSampleWeights)), + .reduceSum(tSampleWeights, LossesHelper.allAxes(getTF(), tSampleWeights)), resultType); } break; @@ -202,7 +202,7 @@ public Operand result() { break; case WEIGHTED_MEAN: case SUM_OVER_BATCH_SIZE: - fResult = getTF().math.divNoNan(total, CastHelper.cast(getTF(), count, resultType)); + fResult = getTF().math.divNoNan(total, cast(getTF(), count, resultType)); break; default: throw new UnsupportedOperationException( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java index 1841c7ee238..467dea19b57 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java @@ -25,33 +25,6 @@ /** Implementation of set operations */ public class SetsOps { - /** - * Enumeration containing the string operation values to be passed to the TensorFlow Sparse Ops - * function {@link SparseOps#denseToDenseSetOperation} - */ - public enum Operation { - A_MINUS_B("a-b"), - B_MINUS_A("b-a"), - INTERSECTION("intersection"), - UNION("union"); - - private final String setOperation; - - Operation(String setOperation) { - this.setOperation = setOperation; - } - - /** - * Gets the set operation String value used to pass as the stringOperation value to {@link - * SparseOps#denseToDenseSetOperation} - * - * @return the set operation String value - */ - public String getSetOperation() { - return setOperation; - } - } - /** * Computes set difference of elements in last dimension of a and b with * 
aMinusB set to true. @@ -69,6 +42,7 @@ public String getSetOperation() { public static Operand difference(Ops tf, Operand a, Operand b) { return difference(tf, a, b, true); } + /** * Computes set difference of elements in last dimension of a and b. * @@ -143,4 +117,31 @@ public static Operand setOperation( setOperationResult.resultValues(), cast(tf, tf.constant(0), a.type())); } + + /** + * Enumeration containing the string operation values to be passed to the TensorFlow Sparse Ops + * function {@link SparseOps#denseToDenseSetOperation} + */ + public enum Operation { + A_MINUS_B("a-b"), + B_MINUS_A("b-a"), + INTERSECTION("intersection"), + UNION("union"); + + private final String setOperation; + + Operation(String setOperation) { + this.setOperation = setOperation; + } + + /** + * Gets the set operation String value used to pass as the stringOperation value to {@link + * SparseOps#denseToDenseSetOperation} + * + * @return the set operation String value + */ + public String getSetOperation() { + return setOperation; + } + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java index 822eb490f22..aadbfeea54b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java @@ -31,29 +31,29 @@ * learning rate per dimension to address two drawbacks: * *

    - *
  • the continual decay of learning rates throughout training - *
  • the need for a manually selected global learning rate + *
  • the continual decay of learning rates throughout training + *
  • the need for a manually selected global learning rate *
* - *

Adadelta is a more robust extension of Adagrad that adapts learning rates based on a - * moving window of gradient updates, instead of accumulating all past gradients. This way, - * Adadelta continues learning even when many updates have been done. Compared to Adagrad, in - * the original version of Adadelta you don't have to set an initial learning rate. In this - * version, initial learning rate can be set, as in most other optimizers. + *

Adadelta is a more robust extension of Adagrad that adapts learning rates based on a moving + * window of gradient updates, instead of accumulating all past gradients. This way, Adadelta + * continues learning even when many updates have been done. Compared to Adagrad, in the original + * version of Adadelta you don't have to set an initial learning rate. In this version, initial + * learning rate can be set, as in most other optimizers. * - *

According to section 4.3 ("Effective Learning rates"), near the end of training step sizes - * converge to 1 which is effectively a high learning rate which would cause divergence. This - * occurs only near the end of the training as gradients and step sizes are small, and the - * epsilon constant in the numerator and denominator dominate past gradients and parameter - * updates which converge the learning rate to 1. + *

According to section 4.3 ("Effective Learning rates"), near the end of training step sizes + * converge to 1 which is effectively a high learning rate which would cause divergence. This occurs + * only near the end of the training as gradients and step sizes are small, and the epsilon constant + * in the numerator and denominator dominate past gradients and parameter updates which converge the + * learning rate to 1. * - *

According to section 4.4("Speech Data"),where a large neural network with 4 hidden layers - * was trained on a corpus of US English data, ADADELTA was used with 100 network replicas.The - * epsilon used is 1e-6 with rho=0.95 which converged faster than ADAGRAD, by the following - * construction: new AdaDelta(graph, 1.0f, 0.95f, 1e-6f); + *

According to section 4.4("Speech Data"),where a large neural network with 4 hidden layers was + * trained on a corpus of US English data, ADADELTA was used with 100 network replicas.The epsilon + * used is 1e-6 with rho=0.95 which converged faster than ADAGRAD, by the following construction: + * new AdaDelta(graph, 1.0f, 0.95f, 1e-6f); * * @see Zeiler, M., 2012 ADADELTA: An Adaptive Learning - * Rate Method. + * Rate Method */ public class AdaDelta extends Optimizer { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index 08f5f18a9cd..2dd05ef31b3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -31,10 +31,10 @@ * how frequently a parameter gets updated during training. The more updates a parameter receives, * the smaller the updates. * - *

- * - * @see Duchi, J, et al., 2011, Adaptive Subgradient Methods for Online Learning and Stochastic Optimization - * @see Duchi, J, et al., 2013, Proximal and First-Order Methods for Convex Optimization, Introduction Section 1. + * @see Duchi, J, et al., 2011, + * Adaptive Subgradient Methods for Online Learning and Stochastic Optimization + * @see Duchi, J, et al., + * 2013, Proximal and First-Order Methods for Convex Optimization, Introduction Section 1 */ public class AdaGrad extends Optimizer { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java index df624e41c4e..7114c33339f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java @@ -40,7 +40,7 @@ * networks as it will require careful initialization of the gradient accumulators for it to train. * * @see Duchi, J, et al., 2011, - * Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. 
+ * Adaptive Subgradient Methods for Online Learning and Stochastic Optimization */ public class AdaGradDA extends Optimizer { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java index cd95bb3bd07..0ecc1ac1451 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java @@ -32,12 +32,10 @@ public class Adamax extends Optimizer { public static final float EPSILON_DEFAULT = 1e-07f; public static final float BETA_ONE_DEFAULT = 0.9f; public static final float BETA_TWO_DEFAULT = 0.999f; - - private float learningRate; private final float betaOne; private final float betaTwo; private final float epsilon; - + private final float learningRate; private Constant learningRateConst; private Constant epsilonConst; private Constant betaOneConst; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java index 66314d2ffe0..5d8c1478231 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java @@ -13,10 +13,11 @@ /** * Optimizer that implements the FTRL algorithm. * + *

This version has support for both online L2 (the L2 penalty given in the paper below) and + * shrinkage-type L2 (which is the addition of an L2 penalty to the loss function). + * * @see McMahan, et - * al., 2013, Algorithm 1 - *

This version has support for both online L2 (the L2 penalty given in the paper above) and - * shrinkage-type L2 (which is the addition of an L2 penalty to the loss function). + * al., 2013, Algorithm 1 */ public class Ftrl extends Optimizer { @@ -29,13 +30,12 @@ public class Ftrl extends Optimizer { public static final float L1STRENGTH_DEFAULT = 0.0f; public static final float L2STRENGTH_DEFAULT = 0.0f; public static final float L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT = 0.0f; - - private float learningRate; private final float learningRatePower; private final float initialAccumulatorValue; private final float l1RegularizationStrength; private final float l2RegularizationStrength; private final float l2ShrinkageRegularizationStrength; + private final float learningRate; /** * Creates a Ftrl Optimizer diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java index f9900a8ee78..5b94b548c0a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java @@ -24,8 +24,6 @@ */ public class Nadam extends Optimizer { - private static final float DECAY_BASE = 0.96f; - private static final float DECAY = 0.004f; public static final float LEARNING_RATE_DEFAULT = 0.001f; public static final float EPSILON_DEFAULT = 1e-8f; public static final float BETA_ONE_DEFAULT = 0.9f; @@ -33,7 +31,8 @@ public class Nadam extends Optimizer { public static final String FIRST_MOMENT = "m"; public static final String SECOND_MOMENT = "v"; public static final String MOMENTUM = "momentum"; - + private static final float DECAY_BASE = 0.96f; + private static final float DECAY = 0.004f; /** The learning rate. 
*/ private final float learningRate; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index fdf56da4a67..ed141831bbe 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -71,14 +71,6 @@ protected Optimizer(Graph graph, String name) { this.globals = new ArrayList<>(); } - /** - * Gets the Optimizer's Ops instance - * @return the Optimizer's Ops instance - */ - public final Ops getTF() { - return tf; - } - /** * Creates a name by combining a variable name and a slot name * @@ -90,6 +82,15 @@ public static String createName(Output variable, String slotNam return variable.op().name() + "-" + slotName; } + /** + * Gets the Optimizer's Ops instance + * + * @return the Optimizer's Ops instance + */ + public final Ops getTF() { + return tf; + } + /** * Minimizes the loss by updating the variables * @@ -299,7 +300,8 @@ private Options() {} * Sets the shared name * * @param sharedName If non-empty, this variable is named in the given bucket with this - * shared_name. Otherwise, the node name is used instead. + * sharedName. Otherwise, the node name is used instead. + * @return this options instance */ public Optimizer.Options sharedName(String sharedName) { this.sharedName = sharedName; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java index b3729dc367f..e86e64971a4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java @@ -27,17 +27,20 @@ /** * Optimizer that implements the RMSProp algorithm. * - *

The gist of RMSprop is to:

    - *
  • Maintain a moving (discounted) average of the square of gradients - *
  • Divide the gradient by the root of this average
+ *

The gist of RMSprop is to: * - *

This implementation of RMSprop uses plain momentum, not Nesterov momentum. + *

    + *
  • Maintain a moving (discounted) average of the square of gradients + *
  • Divide the gradient by the root of this average + *
* - *

The centered version additionally maintains a moving average of the gradients, and uses - * that average to estimate the variance. + *

This implementation of RMSprop uses plain momentum, not Nesterov momentum. + * + *

The centered version additionally maintains a moving average of the gradients, and uses that + * average to estimate the variance. * * @see Hinton G, - * et al. 2012, lecture notes that is inexplicably the canonical reference. + * et al. 2012, lecture notes, that is inexplicably the canonical reference. */ public class RMSProp extends Optimizer { @@ -165,24 +168,20 @@ protected void createSlots(List> variables) { } } - /** - * Creates the RMSProp Slots for Root Mean Squared (RMS), - * MOMENTUM, and Mean Gradient (MG) + * Creates the RMSProp Slots for Root Mean Squared (RMS), MOMENTUM, and Mean Gradient (MG) * * @param v the variable to install in the slot * @param the datatype of the variable. */ private void createRMSPropSlot(Output v) { - Operand rmsInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.type())); + Operand rmsInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.type())); createSlot(v.asOutput(), RMS, rmsInitializer); Operand momentumInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), MOMENTUM, momentumInitializer); if (centered) { - Operand mgInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); + Operand mgInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), MG, mgInitializer); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java index b0fe48967dd..1c027cb5ddf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java @@ -34,7 +34,7 @@ public class CastHelper { */ @SuppressWarnings("unchecked") public static Operand cast( - Ops tf, Operand value, Class requiredType) { + Ops tf, Operand value, Class requiredType) { return 
(value.type() == requiredType) ? (Operand) value : tf.dtypes.cast(value, requiredType); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java index 4ca2c789f28..e730c79cfbf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java @@ -14,8 +14,9 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.*; -import org.tensorflow.ndarray.NdArray; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; import org.tensorflow.types.TInt32; @@ -33,7 +34,9 @@ public class ShapeUtils { /** * Converts a shape operand to a Shape object * + * @param scope the TensorFlow scope * @param dims the Operand containing the shape values + * @param the date type for the shape dimensions. 
* @return a new Shape based on an Operand that contains dimensions */ public static Shape toShape(Scope scope, Operand dims) { @@ -45,8 +48,8 @@ public static Shape toShape(Scope scope, Operand dims) * Converts a TInt32 type Operand to a Java int array * * @param scope the TensorFlow scope - * @param dims the TInt32 Operand - * @return the int array + * @param dims the shape dimensions operand + * @return the int array of the dimensions */ public static int[] getIntArray(Scope scope, Operand dims) { long[] longDims = getLongArray(scope, dims); @@ -66,8 +69,8 @@ public static long[] getLongArray(Scope scope, Operand if (scope.env().isEager()) { return getLongArray(dims.asTensor()); } - try (Session session = new Session((Graph)scope.env()); - TIntegral tensor = (TIntegral)session.runner().fetch(dims).run().get(0)) { + try (Session session = new Session((Graph) scope.env()); + TIntegral tensor = (TIntegral) session.runner().fetch(dims).run().get(0)) { return getLongArray(tensor); } } @@ -76,20 +79,21 @@ public static long[] getLongArray(Scope scope, Operand * Converts a TInt32 or TInt64 to a java long array * * @param dims the dimension tensor + * @param the type of the dimensions, must either be TInt32 or TInt64 type * @return the long array * @throws java.lang.IllegalArgumentException if the dims type is not an integer */ public static long[] getLongArray(T dims) { List result = new ArrayList<>(); if (dims instanceof TInt32) { - ((TInt32)dims).scalars().forEach(s -> result.add((long) s.getInt())); + ((TInt32) dims).scalars().forEach(s -> result.add((long) s.getInt())); } else if (dims instanceof TInt64) { - ((TInt64)dims).scalars().forEach(s -> result.add(s.getLong())); + ((TInt64) dims).scalars().forEach(s -> result.add(s.getLong())); } else if (dims instanceof TUint8) { - ((TUint8)dims).scalars().forEach(s -> result.add(s.getObject().longValue())); - } else { // shouldn't happen - throw new IllegalArgumentException("the data type must be an integer type"); - } + 
((TUint8) dims).scalars().forEach(s -> result.add(s.getObject().longValue())); + } else { // shouldn't happen + throw new IllegalArgumentException("the data type must be an integer type"); + } return result.stream().mapToLong(i -> i).toArray(); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java index 7ceedded018..be46bb5c282 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java @@ -32,7 +32,7 @@ class BinaryCrossentropyTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testUnweighted", false, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 0, 1, 0}; @@ -55,7 +55,7 @@ public void testUnweighted() { public void testUnweightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testUnweightedLogits", true, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 0, 1, 0, 1, 1}; @@ -77,7 +77,7 @@ public void testUnweightedLogits() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testWeighted", false, 0, 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 0, 1, 0}; @@ -102,7 +102,7 @@ public void testWeighted() { public void testWeightedLogits() { try (TestSession 
session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testWeightedLogits", true, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 0, 1, 0, 1, 1}; @@ -128,7 +128,7 @@ public void testLabelSmoothing() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); float labelSmoothing = 0.1F; - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>( tf, "BCE_testWeightedLabS", true, labelSmoothing, 1001L, TFloat64.class); session.run(instance.resetStates()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java index 2b4a1d75467..34fc3eef884 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java @@ -31,7 +31,7 @@ class CategoricalCrossentropyTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testUnweighted", false, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -55,7 +55,7 @@ public void testUnweighted() { public void testUnweightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testUnweightedLogits", true, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -79,7 +79,7 @@ public void testUnweightedLogits() { public void 
testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testWeighted", false, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -104,7 +104,7 @@ public void testWeighted() { public void testWeightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>(tf, "CCE_testWeighted", true, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {0, 1, 0, 0, 0, 1}; @@ -129,7 +129,7 @@ public void testLabelSmoothing() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); float labelSmoothing = 0.1F; - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testWeighted", true, labelSmoothing, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java index 87248d95e48..78b25a21b60 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java @@ -31,7 +31,7 @@ class CategoricalHingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalHinge instance = + CategoricalHinge instance = new CategoricalHinge<>(tf, "CH_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { @@ -64,7 +64,7 @@ public void testUnweighted() { public void 
testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalHinge instance = + CategoricalHinge instance = new CategoricalHinge<>(tf, "CH_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java index a9721ef2f8f..18410416c42 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java @@ -31,7 +31,7 @@ class CosineSimilarityTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CosineSimilarity instance = + CosineSimilarity instance = new CosineSimilarity<>(tf, "CS_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; @@ -54,7 +54,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CosineSimilarity instance = + CosineSimilarity instance = new CosineSimilarity<>(tf, "CS_testWeighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; @@ -80,7 +80,7 @@ public void test_axis() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); int axis = 1; - CosineSimilarity instance = + CosineSimilarity instance = new CosineSimilarity<>(tf, "CS_testWeighted", axis, 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java index 6af5fed4889..90531d21fde 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java @@ -32,8 +32,7 @@ class HingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = - new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); + Hinge instance = new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; double[] predArray = {-0.3, 0.2, -0.1, 1.6, -0.25, -1., 0.5, 0.6}; @@ -55,8 +54,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = - new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); + Hinge instance = new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { -1, 1, -1, 1, diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java index 28020c0fa1c..267578a492c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java @@ -31,7 +31,7 @@ class KLDivergenceTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - KLDivergence instance = + KLDivergence instance = new KLDivergence<>(tf, "KLD_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); float[][] trueArray = {{.5f, .8f, .12f}, {.7f, .43f, .8f}}; @@ 
-54,7 +54,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - KLDivergence instance = + KLDivergence instance = new KLDivergence<>(tf, "KLD_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java index 31c043e0473..1b5b8fb7d49 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java @@ -32,7 +32,7 @@ class LogCoshErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - LogCoshError instance = + LogCoshError instance = new LogCoshError<>(tf, "LogCosh_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 9, 2, -5, -2, 6}; @@ -56,7 +56,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - LogCoshError instance = + LogCoshError instance = new LogCoshError<>(tf, "LogCosh_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java index 73241ecbe9f..984895f2ad9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java @@ -32,7 +32,7 @@ class 
MeanAbsoluteErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanAbsoluteError instance = + MeanAbsoluteError instance = new MeanAbsoluteError<>(tf, "MAE_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0f, instance.getTotal()); @@ -74,7 +74,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanAbsoluteError instance = + MeanAbsoluteError instance = new MeanAbsoluteError<>(tf, "MAE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java index 4c92844b217..0b9e7f6b538 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java @@ -34,7 +34,7 @@ public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { session.setEpsilon(1E-6f); Ops tf = session.getTF(); - MeanAbsolutePercentageError instance = + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError<>(tf, "MAPE_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); session.evaluate(0.0f, instance.getTotal()); @@ -76,7 +76,7 @@ public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { session.setEpsilon(1E-6f); Ops tf = session.getTF(); - MeanAbsolutePercentageError instance = + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError<>(tf, "MAPE_testWeighted", 1001L, TFloat64.class); 
session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java index 0b760213015..e42052a9ef1 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java @@ -33,7 +33,7 @@ class MeanSquaredErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredError instance = + MeanSquaredError instance = new MeanSquaredError<>(tf, "MSE_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); @@ -70,7 +70,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredError instance = + MeanSquaredError instance = new MeanSquaredError<>(tf, "MSE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java index 098a5cb9725..e68d63b8778 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java @@ -32,7 +32,7 @@ class MeanSquaredLogarithmicErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredLogarithmicError instance 
= + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError<>(tf, "MSLE_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); session.evaluate(0.0f, instance.getTotal()); @@ -69,7 +69,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredLogarithmicError instance = + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError<>(tf, "MSLE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java index cf3c3e44719..5631bac15ee 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java @@ -32,7 +32,7 @@ class PoissonTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Poisson instance = + Poisson instance = new Poisson<>(tf, "Poisson_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {4, 8, 12, 8, 1, 3}; @@ -55,8 +55,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Poisson instance = - new Poisson<>(tf, "Poisson_testWeighted", 1001L, TFloat32.class); + Poisson instance = new Poisson<>(tf, "Poisson_testWeighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {4, 8, 12, 8, 1, 3}; float[] predArray = {1, 9, 2, 5, 2, 6}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java index 87af1bd8448..0aece8c8ac9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java @@ -32,7 +32,7 @@ class SparseCategoricalCrossentropyTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testUnweighted", false, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -56,7 +56,7 @@ public void testUnweighted() { public void testUnweightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testWeighted", true, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -79,7 +79,7 @@ public void testUnweightedLogits() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testWeighted", false, -1, 1001L, TFloat32.class); session.run(instance.resetStates()); @@ -105,7 +105,7 @@ public void testWeighted() { public void testWeightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testWeighted", true, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); diff --git 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java index e3376c224f3..2c80b3451ad 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java @@ -32,7 +32,7 @@ class SquaredHingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SquaredHinge instance = + SquaredHinge instance = new SquaredHinge<>(tf, "SCE_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = { @@ -61,7 +61,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SquaredHinge instance = + SquaredHinge instance = new SquaredHinge<>(tf, "SCE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { diff --git a/tools/build_java_api_docs.py b/tools/build_java_api_docs.py new file mode 100644 index 00000000000..42a83703ecf --- /dev/null +++ b/tools/build_java_api_docs.py @@ -0,0 +1,83 @@ +# Lint as: python3 +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Generate TensorFlow Lite Java reference docs for TensorFlow.org.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pathlib +import shutil +import tempfile + +from absl import app +from absl import flags + +from tensorflow_docs.api_generator import gen_java + +FLAGS = flags.FLAGS + +# These flags are required by infrastructure, not all of them are used. +flags.DEFINE_string('output_dir', '/tmp/java_api/', + ("Use this branch as the root version and don't" + ' create in version directory')) + +flags.DEFINE_string('site_path', 'java/api_docs/java', + 'Path prefix in the _toc.yaml') + +flags.DEFINE_string('code_url_prefix', None, + '[UNUSED] The url prefix for links to code.') + +flags.DEFINE_bool( + 'search_hints', True, + '[UNUSED] Include metadata search hints in the generated files') + +# __file__ is the path to this file +TOOLS_DIR = pathlib.Path(__file__).resolve().parent +REPO_ROOT = TOOLS_DIR.parent + +def overlay(from_root, to_root): + for from_path in pathlib.Path(from_root).rglob('*'): + relpath = from_path.relative_to(from_root) + to_path = to_root/relpath + if from_path.is_file(): + assert not to_path.exists() + shutil.copyfile(from_path, to_path) + else: + to_path.mkdir(exist_ok=True) + +def main(unused_argv): + merged_source = pathlib.Path(tempfile.mkdtemp()) + (merged_source / 'java/org').mkdir(parents=True) + + shutil.copytree(REPO_ROOT/'tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/', + merged_source/'java/org/tensorflow') + overlay(REPO_ROOT/'tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow', + merged_source/'java/org/tensorflow') + shutil.copytree(REPO_ROOT/'tensorflow-framework/src/main/java/org/tensorflow/framework', + merged_source/'java/org/tensorflow/framework') + shutil.copytree(REPO_ROOT/'ndarray/src/main/java/org/tensorflow/ndarray', + 
merged_source/'java/org/tensorflow/ndarray') + + gen_java.gen_java_docs( + package='org.tensorflow', + source_path=merged_source / 'java', + output_dir=pathlib.Path(FLAGS.output_dir), + site_path=pathlib.Path(FLAGS.site_path)) + + +if __name__ == '__main__': + flags.mark_flags_as_required(['output_dir']) + app.run(main)