diff --git a/src/TensorFlowNET.Core/Data/OwnedIterator.cs b/src/TensorFlowNET.Core/Data/OwnedIterator.cs
index 1dafc87ea..6f6fd0b58 100644
--- a/src/TensorFlowNET.Core/Data/OwnedIterator.cs
+++ b/src/TensorFlowNET.Core/Data/OwnedIterator.cs
@@ -13,7 +13,7 @@ public class OwnedIterator : IDisposable
IDatasetV2 _dataset;
TensorSpec[] _element_spec;
dataset_ops ops = new dataset_ops();
- Tensor _deleter;
+ //Tensor _deleter;
Tensor _iterator_resource;
public OwnedIterator(IDatasetV2 dataset)
@@ -26,7 +26,6 @@ void _create_iterator(IDatasetV2 dataset)
dataset = dataset.apply_options();
_dataset = dataset;
_element_spec = dataset.element_spec;
- // _flat_output_types =
_iterator_resource = ops.anonymous_iterator_v3(_dataset.output_types, _dataset.output_shapes);
// TODO(Rinne): deal with graph mode.
ops.make_iterator(dataset.variant_tensor, _iterator_resource);
diff --git a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
index e02642dcf..a462a68eb 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
@@ -62,7 +62,7 @@ void evaluate(NDArray x, NDArray y,
bool use_multiprocessing = false,
bool return_dict = false);
- Tensors predict(Tensor x,
+ Tensors predict(Tensors x,
int batch_size = -1,
int verbose = 0,
int steps = -1,
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Predict.cs b/src/TensorFlowNET.Keras/Engine/Model.Predict.cs
index c27ea9090..984bcb5dc 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Predict.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Predict.cs
@@ -49,7 +49,7 @@ public Tensors predict(IDatasetV2 dataset,
///
///
///
- public Tensors predict(Tensor x,
+ public Tensors predict(Tensors x,
int batch_size = -1,
int verbose = 0,
int steps = -1,
@@ -115,12 +115,12 @@ Tensors PredictInternal(DataHandler data_handler, int verbose)
Tensors run_predict_step(OwnedIterator iterator)
{
var data = iterator.next();
- var outputs = predict_step(data[0]);
+ var outputs = predict_step(data);
tf_with(ops.control_dependencies(new object[0]), ctl => _predict_counter.assign_add(1));
return outputs;
}
- Tensors predict_step(Tensor data)
+ Tensors predict_step(Tensors data)
{
return Apply(data, training: false);
}
diff --git a/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs b/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs
index 490178bc9..a762a1c65 100644
--- a/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs
@@ -1,27 +1,17 @@
-using Microsoft.VisualStudio.TestPlatform.Utilities;
-using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading.Tasks;
-using System.Xml.Linq;
-using Tensorflow.Operations;
-using static Tensorflow.Binding;
-using static Tensorflow.KerasApi;
-using Tensorflow.NumPy;
-using Microsoft.VisualBasic;
-using static HDF.PInvoke.H5T;
-using Tensorflow.Keras.UnitTest.Helpers;
+using Tensorflow;
using Tensorflow.Keras.Optimizers;
+using Tensorflow.NumPy;
+using static Tensorflow.KerasApi;
-namespace Tensorflow.Keras.UnitTest
+namespace TensorFlowNET.Keras.UnitTest
{
[TestClass]
public class MultiInputModelTest
{
[TestMethod]
- public void SimpleModel()
+ public void LeNetModel()
{
var inputs = keras.Input((28, 28, 1));
var conv1 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs);
@@ -40,7 +30,7 @@ public void SimpleModel()
var concat = keras.layers.Concatenate().Apply((flat1, flat1_2));
var dense1 = keras.layers.Dense(512, activation: "relu").Apply(concat);
var dense2 = keras.layers.Dense(128, activation: "relu").Apply(dense1);
- var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2);
+ var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2);
var output = keras.layers.Softmax(-1).Apply(dense3);
var model = keras.Model((inputs, inputs_2), output);
@@ -52,7 +42,7 @@ public void SimpleModel()
{
TrainDir = "mnist",
OneHot = false,
- ValidationSize = 59000,
+ ValidationSize = 59900,
}).Result;
var loss = keras.losses.SparseCategoricalCrossentropy();
@@ -64,6 +54,11 @@ public void SimpleModel()
var x = new NDArray[] { x1, x2 };
model.fit(x, dataset.Train.Labels, batch_size: 8, epochs: 3);
+
+ x1 = np.ones((1, 28, 28, 1), TF_DataType.TF_FLOAT);
+ x2 = np.zeros((1, 28, 28, 1), TF_DataType.TF_FLOAT);
+ var pred = model.predict((x1, x2));
+ Console.WriteLine(pred);
}
}
}