diff --git a/src/TensorFlowNET.Core/Data/OwnedIterator.cs b/src/TensorFlowNET.Core/Data/OwnedIterator.cs index 1dafc87ea..6f6fd0b58 100644 --- a/src/TensorFlowNET.Core/Data/OwnedIterator.cs +++ b/src/TensorFlowNET.Core/Data/OwnedIterator.cs @@ -13,7 +13,7 @@ public class OwnedIterator : IDisposable IDatasetV2 _dataset; TensorSpec[] _element_spec; dataset_ops ops = new dataset_ops(); - Tensor _deleter; + //Tensor _deleter; Tensor _iterator_resource; public OwnedIterator(IDatasetV2 dataset) @@ -26,7 +26,6 @@ void _create_iterator(IDatasetV2 dataset) dataset = dataset.apply_options(); _dataset = dataset; _element_spec = dataset.element_spec; - // _flat_output_types = _iterator_resource = ops.anonymous_iterator_v3(_dataset.output_types, _dataset.output_shapes); // TODO(Rinne): deal with graph mode. ops.make_iterator(dataset.variant_tensor, _iterator_resource); diff --git a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs index e02642dcf..a462a68eb 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs @@ -62,7 +62,7 @@ void evaluate(NDArray x, NDArray y, bool use_multiprocessing = false, bool return_dict = false); - Tensors predict(Tensor x, + Tensors predict(Tensors x, int batch_size = -1, int verbose = 0, int steps = -1, diff --git a/src/TensorFlowNET.Keras/Engine/Model.Predict.cs b/src/TensorFlowNET.Keras/Engine/Model.Predict.cs index c27ea9090..984bcb5dc 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Predict.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Predict.cs @@ -49,7 +49,7 @@ public Tensors predict(IDatasetV2 dataset, /// /// /// - public Tensors predict(Tensor x, + public Tensors predict(Tensors x, int batch_size = -1, int verbose = 0, int steps = -1, @@ -115,12 +115,12 @@ Tensors PredictInternal(DataHandler data_handler, int verbose) Tensors run_predict_step(OwnedIterator iterator) { var data = iterator.next(); - var outputs = predict_step(data[0]); + var outputs = predict_step(data); tf_with(ops.control_dependencies(new object[0]), ctl => _predict_counter.assign_add(1)); return outputs; } - Tensors predict_step(Tensor data) + Tensors predict_step(Tensors data) { return Apply(data, training: false); } diff --git a/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs b/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs index 490178bc9..a762a1c65 100644 --- a/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs @@ -1,27 +1,17 @@ -using Microsoft.VisualStudio.TestPlatform.Utilities; -using Microsoft.VisualStudio.TestTools.UnitTesting; +using Microsoft.VisualStudio.TestTools.UnitTesting; using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; -using System.Xml.Linq; -using Tensorflow.Operations; -using static Tensorflow.Binding; -using static Tensorflow.KerasApi; -using Tensorflow.NumPy; -using Microsoft.VisualBasic; -using static HDF.PInvoke.H5T; -using Tensorflow.Keras.UnitTest.Helpers; +using Tensorflow; using Tensorflow.Keras.Optimizers; +using Tensorflow.NumPy; +using static Tensorflow.KerasApi; -namespace Tensorflow.Keras.UnitTest +namespace TensorFlowNET.Keras.UnitTest { [TestClass] public class MultiInputModelTest { [TestMethod] - public void SimpleModel() + public void LeNetModel() { var inputs = keras.Input((28, 28, 1)); var conv1 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs); @@ -40,7 +30,7 @@ public void SimpleModel() var concat = keras.layers.Concatenate().Apply((flat1, flat1_2)); var dense1 = keras.layers.Dense(512, activation: "relu").Apply(concat); var dense2 = keras.layers.Dense(128, activation: "relu").Apply(dense1); - var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2); + var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2); var output = keras.layers.Softmax(-1).Apply(dense3); var model = keras.Model((inputs, inputs_2), output); @@ -52,7 +42,7 @@ public void SimpleModel() { TrainDir = "mnist", OneHot = false, - ValidationSize = 59000, + ValidationSize = 59900, }).Result; var loss = keras.losses.SparseCategoricalCrossentropy(); @@ -64,6 +54,11 @@ public void SimpleModel() var x = new NDArray[] { x1, x2 }; model.fit(x, dataset.Train.Labels, batch_size: 8, epochs: 3); + + x1 = np.ones((1, 28, 28, 1), TF_DataType.TF_FLOAT); + x2 = np.zeros((1, 28, 28, 1), TF_DataType.TF_FLOAT); + var pred = model.predict((x1, x2)); + Console.WriteLine(pred); } } }