Skip to content

Commit 0ee50d3

Browse files
committed
Add double to NDArrayConverter.
1 parent 271dcef commit 0ee50d3

File tree

3 files changed

+45
-53
lines changed

3 files changed

+45
-53
lines changed

src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ public unsafe static T Scalar<T>(NDArray nd) where T : unmanaged
1414
TF_DataType.TF_FLOAT => Scalar<T>(*(float*)nd.data),
1515
TF_DataType.TF_INT32 => Scalar<T>(*(int*)nd.data),
1616
TF_DataType.TF_INT64 => Scalar<T>(*(long*)nd.data),
17-
_ => throw new NotImplementedException("")
17+
TF_DataType.TF_DOUBLE => Scalar<T>(*(double*)nd.data),
18+
_ => throw new NotImplementedException(nameof(NDArrayConverter))
1819
};
1920

2021
static T Scalar<T>(byte input)
@@ -23,7 +24,8 @@ static T Scalar<T>(byte input)
2324
TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte),
2425
TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32),
2526
TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single),
26-
_ => throw new NotImplementedException("")
27+
TypeCode.Double => (T)Convert.ChangeType(input, TypeCode.Double),
28+
_ => throw new NotImplementedException(nameof(NDArrayConverter))
2729
};
2830

2931
static T Scalar<T>(float input)
@@ -32,7 +34,8 @@ static T Scalar<T>(float input)
3234
TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte),
3335
TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32),
3436
TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single),
35-
_ => throw new NotImplementedException("")
37+
TypeCode.Double => (T)Convert.ChangeType(input, TypeCode.Double),
38+
_ => throw new NotImplementedException(nameof(NDArrayConverter))
3639
};
3740

3841
static T Scalar<T>(int input)
@@ -41,7 +44,8 @@ static T Scalar<T>(int input)
4144
TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte),
4245
TypeCode.Int64 => (T)Convert.ChangeType(input, TypeCode.Int64),
4346
TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single),
44-
_ => throw new NotImplementedException("")
47+
TypeCode.Double => (T)Convert.ChangeType(input, TypeCode.Double),
48+
_ => throw new NotImplementedException(nameof(NDArrayConverter))
4549
};
4650

4751
static T Scalar<T>(long input)
@@ -50,7 +54,8 @@ static T Scalar<T>(long input)
5054
TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte),
5155
TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32),
5256
TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single),
53-
_ => throw new NotImplementedException("")
57+
TypeCode.Double => (T)Convert.ChangeType(input, TypeCode.Double),
58+
_ => throw new NotImplementedException(nameof(NDArrayConverter))
5459
};
5560

5661
public static unsafe Array ToMultiDimArray<T>(NDArray nd) where T : unmanaged
@@ -65,7 +70,7 @@ public static unsafe Array ToMultiDimArray<T>(NDArray nd) where T : unmanaged
6570
T[,,,] array => Addr(array),
6671
T[,,,,] array => Addr(array),
6772
T[,,,,,] array => Addr(array),
68-
_ => throw new NotImplementedException("")
73+
_ => throw new NotImplementedException(nameof(NDArrayConverter))
6974
};
7075

7176
System.Buffer.MemoryCopy(nd.data.ToPointer(), addr, nd.bytesize, nd.bytesize);

src/TensorFlowNET.Keras/Engine/Model.Predict.cs

Lines changed: 34 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using Tensorflow.NumPy;
2-
using System;
1+
using System;
32
using System.Collections.Generic;
43
using System.Linq;
54
using Tensorflow.Keras.ArgsDefinition;
@@ -33,40 +32,7 @@ public Tensors predict(IDatasetV2 dataset,
3332
StepsPerExecution = _steps_per_execution
3433
});
3534

36-
var callbacks = new CallbackList(new CallbackParams
37-
{
38-
Model = this,
39-
Verbose = verbose,
40-
Epochs = 1,
41-
Steps = data_handler.Inferredsteps
42-
});
43-
44-
Tensor batch_outputs = null;
45-
_predict_counter.assign(0);
46-
callbacks.on_predict_begin();
47-
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
48-
{
49-
foreach (var step in data_handler.steps())
50-
{
51-
callbacks.on_predict_batch_begin(step);
52-
var tmp_batch_outputs = run_predict_step(iterator);
53-
if (batch_outputs == null)
54-
{
55-
batch_outputs = tmp_batch_outputs[0];
56-
}
57-
else
58-
{
59-
batch_outputs = tf.concat(new Tensor[] { batch_outputs, tmp_batch_outputs[0] }, axis: 0);
60-
}
61-
62-
var end_step = step + data_handler.StepIncrement;
63-
callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } });
64-
}
65-
GC.Collect();
66-
}
67-
68-
callbacks.on_predict_end();
69-
return batch_outputs;
35+
return PredictInternal(data_handler, verbose);
7036
}
7137

7238
/// <summary>
@@ -105,23 +71,45 @@ public Tensors predict(Tensor x,
10571
StepsPerExecution = _steps_per_execution
10672
});
10773

108-
Tensors outputs = null;
74+
return PredictInternal(data_handler, verbose);
75+
}
76+
77+
Tensors PredictInternal(DataHandler data_handler, int verbose)
78+
{
79+
var callbacks = new CallbackList(new CallbackParams
80+
{
81+
Model = this,
82+
Verbose = verbose,
83+
Epochs = 1,
84+
Steps = data_handler.Inferredsteps
85+
});
86+
87+
Tensor batch_outputs = null;
10988
_predict_counter.assign(0);
110-
// callbacks.on_predict_begin()
89+
callbacks.on_predict_begin();
11190
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
11291
{
113-
foreach(var step in data_handler.steps())
92+
foreach (var step in data_handler.steps())
11493
{
115-
// callbacks.on_predict_batch_begin(step)
116-
var batch_outputs = run_predict_step(iterator);
117-
outputs = batch_outputs;
94+
callbacks.on_predict_batch_begin(step);
95+
var tmp_batch_outputs = run_predict_step(iterator);
96+
if (batch_outputs == null)
97+
{
98+
batch_outputs = tmp_batch_outputs[0];
99+
}
100+
else
101+
{
102+
batch_outputs = tf.concat(new Tensor[] { batch_outputs, tmp_batch_outputs[0] }, axis: 0);
103+
}
104+
118105
var end_step = step + data_handler.StepIncrement;
119-
// callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs})
106+
callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } });
120107
}
121-
GC.Collect();
122108
}
123-
// callbacks.on_predict_end()
124-
return outputs;
109+
110+
callbacks.on_predict_end();
111+
112+
return batch_outputs;
125113
}
126114

127115
Tensors run_predict_step(OwnedIterator iterator)

src/TensorFlowNET.Keras/Engine/Model.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ public partial class Model : Layer, IModel
3636
IVariableV1 _predict_counter;
3737
bool _base_model_initialized;
3838
bool stop_training;
39-
DataHandler data_handler;
4039

4140
public OptimizerV2 Optimizer
4241
{

0 commit comments

Comments
 (0)