Description
System information
- OS: macOS Big Sur 11.3.1
- TensorFlow-java version: 0.3.1
- Java version: I tried both 1.8 and 11
Issue
I can load the biggan-deep model into Java with `SavedModelBundle.Loader`, but running it fails with:

```
TFFailedPreconditionException: Error while reading resource variable prev_truncation from Container: localhost. This could mean that the variable was uninitialized. Not found: Container localhost does not exist. (Could not find resource: localhost/prev_truncation)
```
The same model runs correctly in Python. I first hit this error when loading and running the model with DJL, then tried plain TF-Java to find out whether the problem lies in DJL or in TF-Java. It fails with TF-Java as well.
Describe the expected behavior
The model should run without errors, as it does in Python.
Java code to reproduce the issue
Only `main` and `loadModel(...)` matter here; the remaining methods just build the input tensors.
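```java
import java.lang.reflect.Field;
import java.util.Random;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.types.TFloat32;

public class Main {

    public static void main(String[] args) {
        int[] input = {100, 207, 971, 970, 933}; // image classes
        // try-with-resources closes the bundle and every native tensor, even on failure
        try (SavedModelBundle model = loadModel("src/test/resources/biggan-deep-128_1", new String[0]);
             Tensor y = oneHot(input, 1000);             // shape [5, 1000]
             Tensor z = truncNormal(input.length, 128);  // shape [5, 128]
             Tensor truncation = TFloat32.scalarOf(0.5f);
             // Session.run() throws the TFFailedPreconditionException here
             Tensor result = model.session().runner()
                     .feed("y", y)
                     .feed("z", z)
                     .feed("truncation", truncation)
                     .fetch("G_trunc_output")
                     .run()
                     .get(0)) {
            System.out.println("Output shape: " + result.shape());
        }
    }

    // The model's meta graph is untagged, so the loader's private "tags" field
    // is set directly via reflection to request an empty tag set.
    private static SavedModelBundle loadModel(String dir, String[] tags) {
        SavedModelBundle.Loader loader = SavedModelBundle.loader(dir);
        try {
            Field field = SavedModelBundle.Loader.class.getDeclaredField("tags");
            field.setAccessible(true);
            field.set(loader, tags);
        } catch (ReflectiveOperationException e) {
            throw new AssertionError(e);
        }
        return loader.load();
    }

    // Standard-normal samples truncated to [-2, 2] by rejection sampling.
    private static Tensor truncNormal(int row, int col) {
        float[][] dist = new float[row][col];
        Random random = new Random();
        for (int i = 0; i < row; i++) {
            for (int j = 0; j < col; j++) {
                double sample = random.nextGaussian();
                while (sample < -2 || sample > 2) {
                    sample = random.nextGaussian();
                }
                dist[i][j] = (float) sample;
            }
        }
        return toTensor(Shape.of(row, col), dist);
    }

    // One row per class id, with 1.0f at that id and 0.0f elsewhere.
    private static Tensor oneHot(int[] input, int numCategories) {
        float[][] dist = new float[input.length][numCategories]; // Java zero-fills new arrays
        for (int i = 0; i < input.length; i++) {
            dist[i][input[i]] = 1.0f;
        }
        return toTensor(Shape.of(input.length, numCategories), dist);
    }

    // Copies a 2-D Java array into a TFloat32 tensor of the given shape.
    private static TFloat32 toTensor(Shape shape, float[][] data) {
        FloatNdArray mat = NdArrays.ofFloats(shape);
        for (int i = 0; i < data.length; i++) {
            // NdArrays.vectorOf avoids allocating (and leaking) a native tensor per row
            mat.set(NdArrays.vectorOf(data[i]), i);
        }
        return TFloat32.tensorOf(mat);
    }
}
```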
build.gradle:

```groovy
dependencies {
    compile group: 'org.tensorflow', name: 'tensorflow-core-platform', version: '0.3.1'
}
```
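For a deterministic repro, and to check the inputs without TensorFlow in the loop, the arrays behind `oneHot` and `truncNormal` can be built in plain Java with a fixed seed. This is a minimal sketch; the class `BigGanInputs`, its method names, and the `scale` and `seed` parameters are hypothetical. The TF Hub BigGAN example notebook multiplies the truncated-normal sample by the truncation value; whether this SavedModel expects `z` to be pre-scaled that way is an assumption, so the factor is exposed as a parameter (pass `1.0f` to match the repro code).

```java
import java.util.Random;

// Hypothetical helper: builds BigGAN-style inputs as plain Java arrays, no TensorFlow needed.
public class BigGanInputs {

    // One row per class id: 1.0f at that id, 0.0f elsewhere.
    public static float[][] oneHot(int[] classIds, int numCategories) {
        float[][] out = new float[classIds.length][numCategories]; // Java zero-fills new arrays
        for (int i = 0; i < classIds.length; i++) {
            if (classIds[i] < 0 || classIds[i] >= numCategories) {
                throw new IllegalArgumentException("class id out of range: " + classIds[i]);
            }
            out[i][classIds[i]] = 1.0f;
        }
        return out;
    }

    // Standard-normal samples truncated to [-2, 2] by rejection, each multiplied by `scale`.
    // scale = 1.0f matches truncNormal(...) in the report; passing the truncation value
    // mirrors the TF Hub example notebook (an assumption about what this model expects).
    public static float[][] truncatedNormal(int rows, int cols, float scale, long seed) {
        Random random = new Random(seed); // fixed seed -> identical samples on every run
        float[][] out = new float[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                double sample;
                do {
                    sample = random.nextGaussian();
                } while (sample < -2 || sample > 2); // reject values outside [-2, 2]
                out[i][j] = (float) (scale * sample);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        int[] classes = {100, 207, 971, 970, 933};
        float[][] y = oneHot(classes, 1000);
        float[][] z = truncatedNormal(classes.length, 128, 0.5f, 42L);
        // prints "y: 5x1000, z: 5x128"
        System.out.println("y: " + y.length + "x" + y[0].length + ", z: " + z.length + "x" + z[0].length);
    }
}
```

With TF-Java 0.3.x, the resulting `float[][]` arrays should convert to tensors with `TFloat32.tensorOf(StdArrays.ndCopyOf(y))` (`org.tensorflow.ndarray.StdArrays`), which takes the shape from the array itself instead of copying row by row.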
Other info / logs
Console output:
```
> Task :Main.main()
2021-08-04 22:07:22.105450: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:32] Reading SavedModel from: src/test/resources/biggan-deep-128_1
2021-08-04 22:07:22.295033: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:55] Reading meta graph with tags { }
2021-08-04 22:07:22.295055: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:93] Reading SavedModel debug info (if present) from: src/test/resources/biggan-deep-128_1
2021-08-04 22:07:22.295136: I external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-08-04 22:07:23.514662: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { }; Status: success: OK. Took 1409219 microseconds.
Exception in thread "main" org.tensorflow.exceptions.TFFailedPreconditionException: Error while reading resource variable prev_truncation from Container: localhost. This could mean that the variable was uninitialized. Not found: Container localhost does not exist. (Could not find resource: localhost/prev_truncation)
     [[{{node Equal/ReadVariableOp}}]]
    at org.tensorflow.internal.c_api.AbstractTF_Status.throwExceptionIfNotOK(AbstractTF_Status.java:95)
    at org.tensorflow.Session.run(Session.java:691)
    at org.tensorflow.Session.access$100(Session.java:72)
    at org.tensorflow.Session$Runner.runHelper(Session.java:381)
    at org.tensorflow.Session$Runner.run(Session.java:329)
    at ai.tf.Main.main(Main.java:32)
> Task :Main.main() FAILED
Execution failed for task ':Main.main()'.
> Process 'command '/Library/Java/JavaVirtualMachines/amazon-corretto-8.jdk/Contents/Home/bin/java'' finished with non-zero exit value 1
* Try:
Run with --stacktrace option to get the stack trace. Run with --info or --debug option to get more log output. Run with --scan to get full insights.
```