Skip to content

Support mutiple inputs of keras modek.predict. #1000

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/TensorFlowNET.Core/Data/OwnedIterator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public class OwnedIterator : IDisposable
IDatasetV2 _dataset;
TensorSpec[] _element_spec;
dataset_ops ops = new dataset_ops();
Tensor _deleter;
//Tensor _deleter;
Tensor _iterator_resource;

public OwnedIterator(IDatasetV2 dataset)
Expand All @@ -26,7 +26,6 @@ void _create_iterator(IDatasetV2 dataset)
dataset = dataset.apply_options();
_dataset = dataset;
_element_spec = dataset.element_spec;
// _flat_output_types =
_iterator_resource = ops.anonymous_iterator_v3(_dataset.output_types, _dataset.output_shapes);
// TODO(Rinne): deal with graph mode.
ops.make_iterator(dataset.variant_tensor, _iterator_resource);
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Keras/Engine/IModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ void evaluate(NDArray x, NDArray y,
bool use_multiprocessing = false,
bool return_dict = false);

Tensors predict(Tensor x,
Tensors predict(Tensors x,
int batch_size = -1,
int verbose = 0,
int steps = -1,
Expand Down
6 changes: 3 additions & 3 deletions src/TensorFlowNET.Keras/Engine/Model.Predict.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public Tensors predict(IDatasetV2 dataset,
/// <param name="workers"></param>
/// <param name="use_multiprocessing"></param>
/// <returns></returns>
public Tensors predict(Tensor x,
public Tensors predict(Tensors x,
int batch_size = -1,
int verbose = 0,
int steps = -1,
Expand Down Expand Up @@ -115,12 +115,12 @@ Tensors PredictInternal(DataHandler data_handler, int verbose)
Tensors run_predict_step(OwnedIterator iterator)
{
var data = iterator.next();
var outputs = predict_step(data[0]);
var outputs = predict_step(data);
tf_with(ops.control_dependencies(new object[0]), ctl => _predict_counter.assign_add(1));
return outputs;
}

Tensors predict_step(Tensor data)
Tensors predict_step(Tensors data)
{
return Apply(data, training: false);
}
Expand Down
31 changes: 13 additions & 18 deletions test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs
Original file line number Diff line number Diff line change
@@ -1,27 +1,17 @@
using Microsoft.VisualStudio.TestPlatform.Utilities;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Xml.Linq;
using Tensorflow.Operations;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.NumPy;
using Microsoft.VisualBasic;
using static HDF.PInvoke.H5T;
using Tensorflow.Keras.UnitTest.Helpers;
using Tensorflow;
using Tensorflow.Keras.Optimizers;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.UnitTest
namespace TensorFlowNET.Keras.UnitTest
{
[TestClass]
public class MultiInputModelTest
{
[TestMethod]
public void SimpleModel()
public void LeNetModel()
{
var inputs = keras.Input((28, 28, 1));
var conv1 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs);
Expand All @@ -40,7 +30,7 @@ public void SimpleModel()
var concat = keras.layers.Concatenate().Apply((flat1, flat1_2));
var dense1 = keras.layers.Dense(512, activation: "relu").Apply(concat);
var dense2 = keras.layers.Dense(128, activation: "relu").Apply(dense1);
var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2);
var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2);
var output = keras.layers.Softmax(-1).Apply(dense3);

var model = keras.Model((inputs, inputs_2), output);
Expand All @@ -52,7 +42,7 @@ public void SimpleModel()
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 59000,
ValidationSize = 59900,
}).Result;

var loss = keras.losses.SparseCategoricalCrossentropy();
Expand All @@ -64,6 +54,11 @@ public void SimpleModel()

var x = new NDArray[] { x1, x2 };
model.fit(x, dataset.Train.Labels, batch_size: 8, epochs: 3);

x1 = np.ones((1, 28, 28, 1), TF_DataType.TF_FLOAT);
x2 = np.zeros((1, 28, 28, 1), TF_DataType.TF_FLOAT);
var pred = model.predict((x1, x2));
Console.WriteLine(pred);
}
}
}