Skip to content

Commit 3db092b

Browse files
AsakusaRinneOceania2018
authored andcommitted
Support mutiple inputs of keras modek.predict.
1 parent 8550dcc commit 3db092b

File tree

4 files changed

+18
-24
lines changed

4 files changed

+18
-24
lines changed

src/TensorFlowNET.Core/Data/OwnedIterator.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public class OwnedIterator : IDisposable
1313
IDatasetV2 _dataset;
1414
TensorSpec[] _element_spec;
1515
dataset_ops ops = new dataset_ops();
16-
Tensor _deleter;
16+
//Tensor _deleter;
1717
Tensor _iterator_resource;
1818

1919
public OwnedIterator(IDatasetV2 dataset)
@@ -26,7 +26,6 @@ void _create_iterator(IDatasetV2 dataset)
2626
dataset = dataset.apply_options();
2727
_dataset = dataset;
2828
_element_spec = dataset.element_spec;
29-
// _flat_output_types =
3029
_iterator_resource = ops.anonymous_iterator_v3(_dataset.output_types, _dataset.output_shapes);
3130
// TODO(Rinne): deal with graph mode.
3231
ops.make_iterator(dataset.variant_tensor, _iterator_resource);

src/TensorFlowNET.Core/Keras/Engine/IModel.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ void evaluate(NDArray x, NDArray y,
6262
bool use_multiprocessing = false,
6363
bool return_dict = false);
6464

65-
Tensors predict(Tensor x,
65+
Tensors predict(Tensors x,
6666
int batch_size = -1,
6767
int verbose = 0,
6868
int steps = -1,

src/TensorFlowNET.Keras/Engine/Model.Predict.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public Tensors predict(IDatasetV2 dataset,
4949
/// <param name="workers"></param>
5050
/// <param name="use_multiprocessing"></param>
5151
/// <returns></returns>
52-
public Tensors predict(Tensor x,
52+
public Tensors predict(Tensors x,
5353
int batch_size = -1,
5454
int verbose = 0,
5555
int steps = -1,
@@ -115,12 +115,12 @@ Tensors PredictInternal(DataHandler data_handler, int verbose)
115115
Tensors run_predict_step(OwnedIterator iterator)
116116
{
117117
var data = iterator.next();
118-
var outputs = predict_step(data[0]);
118+
var outputs = predict_step(data);
119119
tf_with(ops.control_dependencies(new object[0]), ctl => _predict_counter.assign_add(1));
120120
return outputs;
121121
}
122122

123-
Tensors predict_step(Tensor data)
123+
Tensors predict_step(Tensors data)
124124
{
125125
return Apply(data, training: false);
126126
}

test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs

+13-18
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,17 @@
1-
using Microsoft.VisualStudio.TestPlatform.Utilities;
2-
using Microsoft.VisualStudio.TestTools.UnitTesting;
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
32
using System;
4-
using System.Collections.Generic;
5-
using System.Linq;
6-
using System.Text;
7-
using System.Threading.Tasks;
8-
using System.Xml.Linq;
9-
using Tensorflow.Operations;
10-
using static Tensorflow.Binding;
11-
using static Tensorflow.KerasApi;
12-
using Tensorflow.NumPy;
13-
using Microsoft.VisualBasic;
14-
using static HDF.PInvoke.H5T;
15-
using Tensorflow.Keras.UnitTest.Helpers;
3+
using Tensorflow;
164
using Tensorflow.Keras.Optimizers;
5+
using Tensorflow.NumPy;
6+
using static Tensorflow.KerasApi;
177

18-
namespace Tensorflow.Keras.UnitTest
8+
namespace TensorFlowNET.Keras.UnitTest
199
{
2010
[TestClass]
2111
public class MultiInputModelTest
2212
{
2313
[TestMethod]
24-
public void SimpleModel()
14+
public void LeNetModel()
2515
{
2616
var inputs = keras.Input((28, 28, 1));
2717
var conv1 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs);
@@ -40,7 +30,7 @@ public void SimpleModel()
4030
var concat = keras.layers.Concatenate().Apply((flat1, flat1_2));
4131
var dense1 = keras.layers.Dense(512, activation: "relu").Apply(concat);
4232
var dense2 = keras.layers.Dense(128, activation: "relu").Apply(dense1);
43-
var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2);
33+
var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2);
4434
var output = keras.layers.Softmax(-1).Apply(dense3);
4535

4636
var model = keras.Model((inputs, inputs_2), output);
@@ -52,7 +42,7 @@ public void SimpleModel()
5242
{
5343
TrainDir = "mnist",
5444
OneHot = false,
55-
ValidationSize = 59000,
45+
ValidationSize = 59900,
5646
}).Result;
5747

5848
var loss = keras.losses.SparseCategoricalCrossentropy();
@@ -64,6 +54,11 @@ public void SimpleModel()
6454

6555
var x = new NDArray[] { x1, x2 };
6656
model.fit(x, dataset.Train.Labels, batch_size: 8, epochs: 3);
57+
58+
x1 = np.ones((1, 28, 28, 1), TF_DataType.TF_FLOAT);
59+
x2 = np.zeros((1, 28, 28, 1), TF_DataType.TF_FLOAT);
60+
var pred = model.predict((x1, x2));
61+
Console.WriteLine(pred);
6762
}
6863
}
6964
}

0 commit comments

Comments
 (0)