diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs
index 103d7cfff..324d7e834 100644
--- a/src/TensorFlowNET.Core/Data/DatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs
@@ -19,6 +19,8 @@ public class DatasetV2 : IDatasetV2
public TensorSpec[] structure { get; set; }
+ public int FirstInputTensorCount { get; set; } = 1;
+
public Shape[] output_shapes => structure.Select(x => x.shape).ToArray();
public TF_DataType[] output_types => structure.Select(x => x.dtype).ToArray();
@@ -131,6 +133,7 @@ public IDatasetV2 apply_options()
// (4) Apply stats aggregator options
+ dataset.FirstInputTensorCount = this.FirstInputTensorCount;
return dataset;
}
@@ -142,7 +145,7 @@ public override string ToString()
$"types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}, " +
$"len: {length}";
- public IEnumerator<(Tensor, Tensor)> GetEnumerator()
+ public IEnumerator<(Tensors, Tensors)> GetEnumerator()
{
using var ownedIterator = new OwnedIterator(this);
@@ -158,7 +161,8 @@ public override string ToString()
break;
}
- yield return (results[0], results.Length == 1 ? null : results[1]);
+ yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ?
+ null : new Tensors(results.Skip(FirstInputTensorCount)));
}
}
diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs
index 5cfeb27cc..320cbe348 100644
--- a/src/TensorFlowNET.Core/Data/IDatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs
@@ -4,7 +4,7 @@
namespace Tensorflow
{
- public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>
+ public interface IDatasetV2 : IEnumerable<(Tensors, Tensors)>
{
string[] class_names { get; set; }
@@ -18,6 +18,8 @@ public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>
TensorSpec[] structure { get; set; }
+ int FirstInputTensorCount { get; set; }
+
///
/// Caches the elements in this dataset.
///
diff --git a/src/TensorFlowNET.Core/Data/OwnedIterator.cs b/src/TensorFlowNET.Core/Data/OwnedIterator.cs
index eb91272c7..1dafc87ea 100644
--- a/src/TensorFlowNET.Core/Data/OwnedIterator.cs
+++ b/src/TensorFlowNET.Core/Data/OwnedIterator.cs
@@ -27,7 +27,8 @@ void _create_iterator(IDatasetV2 dataset)
_dataset = dataset;
_element_spec = dataset.element_spec;
// _flat_output_types =
- (_iterator_resource, _deleter) = ops.anonymous_iterator_v2(_dataset.output_types, _dataset.output_shapes);
+ _iterator_resource = ops.anonymous_iterator_v3(_dataset.output_types, _dataset.output_shapes);
+ // TODO(Rinne): deal with graph mode.
ops.make_iterator(dataset.variant_tensor, _iterator_resource);
}
@@ -48,7 +49,7 @@ public Tensor[] next()
public void Dispose()
{
- tf.Runner.Execute(tf.Context, "DeleteIterator", 0, new[] { _iterator_resource, _deleter }, null);
+ //tf.Runner.Execute(tf.Context, "DeleteIterator", 0, new[] { _iterator_resource, _deleter }, null);
}
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
index 8ce1ec655..78882e82d 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
@@ -5,8 +5,8 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public class DataAdapterArgs: IKerasConfig
{
- public Tensor X { get; set; }
- public Tensor Y { get; set; }
+ public Tensors X { get; set; }
+ public Tensors Y { get; set; }
public IDatasetV2 Dataset { get; set; }
public int BatchSize { get; set; } = 32;
public int Steps { get; set; }
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
index fd603a85e..82530e950 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
@@ -5,8 +5,8 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public class DataHandlerArgs: IKerasConfig
{
- public Tensor X { get; set; }
- public Tensor Y { get; set; }
+ public Tensors X { get; set; }
+ public Tensors Y { get; set; }
public IDatasetV2 Dataset { get; set; }
public int BatchSize { get; set; } = 32;
public int StepsPerEpoch { get; set; } = -1;
diff --git a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
index 8bcfcbbbd..e02642dcf 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
@@ -24,6 +24,17 @@ ICallback fit(NDArray x, NDArray y,
int workers = 1,
bool use_multiprocessing = false);
+ ICallback fit(IEnumerable<NDArray> x, NDArray y,
+ int batch_size = -1,
+ int epochs = 1,
+ int verbose = 1,
+ float validation_split = 0f,
+ bool shuffle = true,
+ int initial_epoch = 0,
+ int max_queue_size = 10,
+ int workers = 1,
+ bool use_multiprocessing = false);
+
void save(string filepath,
bool overwrite = true,
bool include_optimizer = true,
diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
index 53401a444..fd4f93fc1 100644
--- a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
+++ b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
@@ -14,7 +14,76 @@ public void Deconstruct(out byte blue, out byte green, out byte red)
red = data[2];
}
- public static implicit operator NDArray(Array array)
+ public static implicit operator NDArray(int[] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(byte[] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(float[] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(double[] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(long[] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(bool[] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(uint[] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(ulong[] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(int[,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(byte[,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(float[,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(double[,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(long[,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(bool[,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(uint[,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(ulong[,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(int[,,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(byte[,,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(float[,,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(double[,,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(long[,,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(bool[,,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(uint[,,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(ulong[,,] array)
=> new NDArray(array);
public unsafe static implicit operator bool(NDArray nd)
diff --git a/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs b/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs
index 6e81216ea..ba7868faa 100644
--- a/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs
+++ b/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs
@@ -25,7 +25,7 @@ private NDArray OpenEntry(ZipArchiveEntry entry)
return array;
using var s = entry.Open();
- return LoadMatrix(s);
+ return (NDArray)LoadMatrix(s);
}
public Array LoadMatrix(Stream stream)
diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.cs b/src/TensorFlowNET.Core/Numpy/NDArray.cs
index 3a2cb3ee2..6e4c6b32c 100644
--- a/src/TensorFlowNET.Core/Numpy/NDArray.cs
+++ b/src/TensorFlowNET.Core/Numpy/NDArray.cs
@@ -49,5 +49,8 @@ public IEnumerator GetEnumerator()
IEnumerator IEnumerable.GetEnumerator()
=> GetEnumerator();
+
+ public static explicit operator NDArray(Array array)
+ => new NDArray(array);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/dataset_ops.cs b/src/TensorFlowNET.Core/Operations/dataset_ops.cs
index 9407fd5aa..c7e627772 100644
--- a/src/TensorFlowNET.Core/Operations/dataset_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/dataset_ops.cs
@@ -1,6 +1,9 @@
using System;
+using Tensorflow.Contexts;
+using Tensorflow.Eager;
using Tensorflow.Framework.Models;
using Tensorflow.Functions;
+using Tensorflow.Operations;
using static Tensorflow.Binding;
namespace Tensorflow
@@ -220,6 +223,37 @@ public Tensor model_dataset(Tensor input_dataset,
return (results[0], results[1]);
}
+ public Tensor anonymous_iterator_v3(TF_DataType[] output_types, Shape[] output_shapes, string name = null)
+ {
+ var ctx = tf.Context;
+ Dictionary<string, object> attrs = new();
+ attrs["output_types"] = output_types;
+ attrs["output_shapes"] = output_shapes;
+ if (ctx.executing_eagerly())
+ {
+ try
+ {
+ var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("AnonymousIteratorV3", name)
+ {
+ attrs = attrs
+ });
+ return result[0];
+ }
+ catch (Exception)
+ {
+ return anonymous_iterator_v3_eager_fallback(output_types, output_shapes, name, ctx);
+ }
+ }
+ return tf.OpDefLib._apply_op_helper("AnonymousIteratorV3", name, attrs).outputs[0];
+ }
+
+ public Tensor anonymous_iterator_v3_eager_fallback(TF_DataType[] output_types, Shape[] output_shapes, string name, Context ctx)
+ {
+ object[] attrs = new object[] { "output_types", output_types, "output_shapes", output_shapes };
+ var result = execute.quick_execute("AnonymousIteratorV3", 1, new Tensor[] { }, attrs, ctx, name);
+ return result[0];
+ }
+
///
/// Makes a new iterator from the given `dataset` and stores it in `iterator`.
///
diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs
index ecd844d1f..7fa4dd443 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensors.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs
@@ -65,6 +65,93 @@ public void Insert(int index, Tensor tensor)
IEnumerator IEnumerable.GetEnumerator()
=> GetEnumerator();
+ public NDArray numpy()
+ {
+ EnsureSingleTensor(this, "numpy");
+ return this[0].numpy();
+ }
+
+ public T[] ToArray<T>() where T: unmanaged
+ {
+ EnsureSingleTensor(this, $"ToArray<{typeof(T)}>");
+ return this[0].ToArray<T>();
+ }
+
+ #region Explicit Conversions
+ public unsafe static explicit operator bool(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to bool");
+ return (bool)tensor[0];
+ }
+
+ public unsafe static explicit operator sbyte(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to sbyte");
+ return (sbyte)tensor[0];
+ }
+
+ public unsafe static explicit operator byte(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to byte");
+ return (byte)tensor[0];
+ }
+
+ public unsafe static explicit operator ushort(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to ushort");
+ return (ushort)tensor[0];
+ }
+
+ public unsafe static explicit operator short(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to short");
+ return (short)tensor[0];
+ }
+
+ public unsafe static explicit operator int(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to int");
+ return (int)tensor[0];
+ }
+
+ public unsafe static explicit operator uint(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to uint");
+ return (uint)tensor[0];
+ }
+
+ public unsafe static explicit operator long(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to long");
+ return (long)tensor[0];
+ }
+
+ public unsafe static explicit operator ulong(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to ulong");
+ return (ulong)tensor[0];
+ }
+
+ public unsafe static explicit operator float(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to float");
+ return (float)tensor[0];
+ }
+
+ public unsafe static explicit operator double(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to double");
+ return (double)tensor[0];
+ }
+
+ public unsafe static explicit operator string(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to string");
+ return (string)tensor[0];
+ }
+ #endregion
+
+ #region Implicit Conversions
public static implicit operator Tensors(Tensor tensor)
=> new Tensors(tensor);
@@ -87,12 +174,26 @@ public static implicit operator Tensor(Tensors tensors)
public static implicit operator Tensor[](Tensors tensors)
=> tensors.items.ToArray();
+ #endregion
+
public void Deconstruct(out Tensor a, out Tensor b)
{
a = items[0];
b = items[1];
}
+ private static void EnsureSingleTensor(Tensors tensors, string methodName)
+ {
+ if(tensors.Length == 0)
+ {
+ throw new ValueError($"Method `{methodName}` of `Tensors` cannot be used when `Tensors` contains no Tensor.");
+ }
+ else if(tensors.Length > 1)
+ {
+ throw new ValueError($"Method `{methodName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor.");
+ }
+ }
+
public override string ToString()
=> items.Count() == 1
? items.First().ToString()
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs
index 3314f5c40..6c7d53b2f 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs
@@ -10,7 +10,7 @@ public abstract class DataAdapter
protected DataAdapterArgs args;
protected IDatasetV2 dataset;
- public virtual bool CanHandle(Tensor x, Tensor y = null)
+ public virtual bool CanHandle(Tensors x, Tensors y = null)
=> throw new NotImplementedException();
public virtual IDatasetV2 GetDataset()
@@ -19,12 +19,18 @@ public virtual IDatasetV2 GetDataset()
public virtual int GetSize()
=> throw new NotImplementedException("");
- public virtual (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
+ public virtual (Tensors, Tensors) Expand1d(Tensors x, Tensors y)
{
- if (x.shape.ndim == 1)
- x = array_ops.expand_dims(x, axis: -1);
- if (y.shape.ndim == 1)
- y = array_ops.expand_dims(y, axis: -1);
+ for(int i = 0; i < x.Length; i++)
+ {
+ if (x[i].shape.ndim == 1)
+ x[i] = array_ops.expand_dims(x[i], axis: -1);
+ }
+ for (int i = 0; i < y.Length; i++)
+ {
+ if (y[i].shape.ndim == 1)
+ y[i] = array_ops.expand_dims(y[i], axis: -1);
+ }
return (x, y);
}
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
index 1ddddd111..4723222f2 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
@@ -93,11 +93,15 @@ long _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
public IEnumerable<(int, OwnedIterator)> enumerate_epochs()
{
+ var data_iterator = new OwnedIterator(_dataset);
foreach (var epoch in range(_initial_epoch, _epochs))
{
if (_insufficient_data)
break;
- using var data_iterator = new OwnedIterator(_dataset);
+ if (_adapter.ShouldRecreateIterator())
+ {
+ data_iterator = new OwnedIterator(_dataset);
+ }
yield return (epoch, data_iterator);
}
// _adapter.on_epoch_end()
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs
index df414b9fd..4bdc49795 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs
@@ -13,10 +13,10 @@ public interface IDataAdapter
/// input features
/// target labels
///
- bool CanHandle(Tensor x, Tensor y = null);
+ bool CanHandle(Tensors x, Tensors y = null);
IDatasetV2 GetDataset();
int GetSize();
- (Tensor, Tensor) Expand1d(Tensor x, Tensor y);
+ (Tensors, Tensors) Expand1d(Tensors x, Tensors y);
bool ShouldRecreateIterator();
}
}
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
index fc61aa715..a7e1d7e34 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
@@ -1,4 +1,5 @@
using System;
+using System.Diagnostics;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;
@@ -33,10 +34,11 @@ public TensorLikeDataAdapter(DataAdapterArgs args)
indices_dataset = indices_dataset.flat_map(slice_batch_indices);
var inputs = new Tensors();
if (args.X != null)
- inputs.Add(args.X);
+ inputs.AddRange(args.X);
if (args.Y != null)
- inputs.Add(args.Y);
+ inputs.AddRange(args.Y);
dataset = slice_inputs(indices_dataset, inputs);
+ dataset.FirstInputTensorCount = args.X.Length;
}
Tensors permutation(Tensors tensor)
@@ -87,8 +89,9 @@ IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensors elements)
return dataset.with_options(new DatasetOptions { });
}
- public override int GetSize()
- => _size;
+ public override int GetSize() => _size;
+
+ public override bool ShouldRecreateIterator() => false;
void _process_tensorlike()
{
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
index 1ebd56d33..39004183b 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
@@ -59,7 +59,62 @@ public ICallback fit(NDArray x, NDArray y,
StepsPerExecution = _steps_per_execution
});
- return FitInternal(data_handler, epochs, verbose);
+ return FitInternal(data_handler, epochs, verbose, validation_data: null,
+ train_step_func: train_step_function);
+ }
+
+ public ICallback fit(IEnumerable<NDArray> x, NDArray y,
+ int batch_size = -1,
+ int epochs = 1,
+ int verbose = 1,
+ float validation_split = 0f,
+ bool shuffle = true,
+ int initial_epoch = 0,
+ int max_queue_size = 10,
+ int workers = 1,
+ bool use_multiprocessing = false)
+ {
+ foreach(var tx in x)
+ {
+ if (tx.dims[0] != y.dims[0])
+ {
+ throw new InvalidArgumentError(
+ $"The array x and y should have same value at dim 0, but got {tx.dims[0]} and {y.dims[0]}");
+ }
+ }
+ int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
+
+ var train_x = x.Select(tx => tx[new Slice(0, train_count)] as Tensor);
+ var train_y = y[new Slice(0, train_count)];
+ var val_x = x.Select(tx => tx[new Slice(train_count)] as Tensor);
+ var val_y = y[new Slice(train_count)];
+
+ var data_handler = new DataHandler(new DataHandlerArgs
+ {
+ X = new Tensors(train_x),
+ Y = train_y,
+ BatchSize = batch_size,
+ InitialEpoch = initial_epoch,
+ Epochs = epochs,
+ Shuffle = shuffle,
+ MaxQueueSize = max_queue_size,
+ Workers = workers,
+ UseMultiprocessing = use_multiprocessing,
+ Model = this,
+ StepsPerExecution = _steps_per_execution
+ });
+
+ if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
+ data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
+ {
+ return FitInternal(data_handler, epochs, verbose, validation_data: null,
+ train_step_func: train_step_multi_inputs_function);
+ }
+ else
+ {
+ return FitInternal(data_handler, epochs, verbose, validation_data: null,
+ train_step_func: train_step_function);
+ }
}
public History fit(IDatasetV2 dataset,
@@ -88,10 +143,12 @@ public History fit(IDatasetV2 dataset,
StepsPerExecution = _steps_per_execution
});
- return FitInternal(data_handler, epochs, verbose, validation_data: validation_data);
+ return FitInternal(data_handler, epochs, verbose, validation_data: validation_data,
+ train_step_func: train_step_function);
}
- History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data = null)
- History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data = null)
+ History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data,
+ Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
{
stop_training = false;
_train_counter.assign(0);
@@ -113,7 +170,7 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV
foreach (var step in data_handler.steps())
{
callbacks.on_train_batch_begin(step);
- logs = train_step_function(data_handler, iterator);
+ logs = train_step_func(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
callbacks.on_train_batch_end(end_step, logs);
}
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Keras/Engine/Model.Train.cs
index 8d85d70de..d8171e2a9 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Train.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Train.cs
@@ -17,12 +17,21 @@ Dictionary train_step_function(DataHandler data_handler, OwnedIte
return outputs;
}
+ Dictionary<string, float> train_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)
+ {
+ var data = iterator.next();
+ var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
+ var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
+ tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
+ return outputs;
+ }
+
///
/// The logic for one training step.
///
///
///
- Dictionary<string, float> train_step(DataHandler data_handler, Tensor x, Tensor y)
+ Dictionary<string, float> train_step(DataHandler data_handler, Tensors x, Tensors y)
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
using var tape = tf.GradientTape();
diff --git a/test/TensorFlowNET.Keras.UnitTest/Gradient.cs b/test/TensorFlowNET.Keras.UnitTest/Gradient.cs
index fad8e1187..f20eae0e0 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Gradient.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Gradient.cs
@@ -9,7 +9,7 @@
namespace TensorFlowNET.Keras.UnitTest;
[TestClass]
-public class GradientTest
+public class GradientTest : EagerModeTestBase
{
public IModel get_actor(int num_states)
{
diff --git a/test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs b/test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs
new file mode 100644
index 000000000..e145ce585
--- /dev/null
+++ b/test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs
@@ -0,0 +1,30 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Tensorflow.NumPy;
+
+namespace Tensorflow.Keras.UnitTest.Helpers
+{
+ public class RandomDataSet : DataSetBase
+ {
+ private Shape _shape;
+
+ public RandomDataSet(Shape shape, int count)
+ {
+ _shape = shape;
+ Debug.Assert(_shape.ndim == 3);
+ long[] dims = new long[4];
+ dims[0] = count;
+ for (int i = 1; i < 4; i++)
+ {
+ dims[i] = _shape[i - 1];
+ }
+ Shape s = new Shape(dims);
+ Data = np.random.normal(0, 2, s);
+ Labels = np.random.uniform(0, 1, (count, 1));
+ }
+ }
+}
diff --git a/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs b/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs
new file mode 100644
index 000000000..490178bc9
--- /dev/null
+++ b/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs
@@ -0,0 +1,69 @@
+using Microsoft.VisualStudio.TestPlatform.Utilities;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using System.Xml.Linq;
+using Tensorflow.Operations;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+using Tensorflow.NumPy;
+using Microsoft.VisualBasic;
+using static HDF.PInvoke.H5T;
+using Tensorflow.Keras.UnitTest.Helpers;
+using Tensorflow.Keras.Optimizers;
+
+namespace Tensorflow.Keras.UnitTest
+{
+ [TestClass]
+ public class MultiInputModelTest
+ {
+ [TestMethod]
+ public void SimpleModel()
+ {
+ var inputs = keras.Input((28, 28, 1));
+ var conv1 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs);
+ var pool1 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv1);
+ var conv2 = keras.layers.Conv2D(32, (3, 3), activation: "relu", padding: "same").Apply(pool1);
+ var pool2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2);
+ var flat1 = keras.layers.Flatten().Apply(pool2);
+
+ var inputs_2 = keras.Input((28, 28, 1));
+ var conv1_2 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs_2);
+ var pool1_2 = keras.layers.MaxPooling2D((4, 4), 4).Apply(conv1_2);
+ var conv2_2 = keras.layers.Conv2D(32, (1, 1), activation: "relu", padding: "same").Apply(pool1_2);
+ var pool2_2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2_2);
+ var flat1_2 = keras.layers.Flatten().Apply(pool2_2);
+
+ var concat = keras.layers.Concatenate().Apply((flat1, flat1_2));
+ var dense1 = keras.layers.Dense(512, activation: "relu").Apply(concat);
+ var dense2 = keras.layers.Dense(128, activation: "relu").Apply(dense1);
+ var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2);
+ var output = keras.layers.Softmax(-1).Apply(dense3);
+
+ var model = keras.Model((inputs, inputs_2), output);
+ model.summary();
+
+ var data_loader = new MnistModelLoader();
+
+ var dataset = data_loader.LoadAsync(new ModelLoadSetting
+ {
+ TrainDir = "mnist",
+ OneHot = false,
+ ValidationSize = 59000,
+ }).Result;
+
+ var loss = keras.losses.SparseCategoricalCrossentropy();
+ var optimizer = new Adam(0.001f);
+ model.compile(optimizer, loss, new string[] { "accuracy" });
+
+ NDArray x1 = np.reshape(dataset.Train.Data, (dataset.Train.Data.shape[0], 28, 28, 1));
+ NDArray x2 = x1;
+
+ var x = new NDArray[] { x1, x2 };
+ model.fit(x, dataset.Train.Labels, batch_size: 8, epochs: 3);
+ }
+ }
+}
diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs
index e778a5a4a..385ec0f7c 100644
--- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs
@@ -13,6 +13,7 @@
using Tensorflow.Keras.Optimizers;
using static Tensorflow.KerasApi;
using Tensorflow.NumPy;
+using Tensorflow.Keras.UnitTest.Helpers;
using static TensorFlowNET.Keras.UnitTest.SaveModel.SequentialModelSave;
namespace TensorFlowNET.Keras.UnitTest.SaveModel;
diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs
index 5b7c2b62e..251afde3d 100644
--- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs
@@ -6,7 +6,7 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Optimizers;
-using Tensorflow.NumPy;
+using Tensorflow.Keras.UnitTest.Helpers;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
@@ -175,24 +175,4 @@ public void AlexnetFromSequential()
// )
#endregion
}
-
- public class RandomDataSet : DataSetBase
- {
- private Shape _shape;
-
- public RandomDataSet(Shape shape, int count)
- {
- _shape = shape;
- Debug.Assert(_shape.ndim == 3);
- long[] dims = new long[4];
- dims[0] = count;
- for (int i = 1; i < 4; i++)
- {
- dims[i] = _shape[i - 1];
- }
- Shape s = new Shape(dims);
- Data = np.random.normal(0, 2, s);
- Labels = np.random.uniform(0, 1, (count, 1));
- }
- }
}
\ No newline at end of file