Skip to content

Commit 85aba7a

Browse files
authored
Merge pull request #1014 from Wanglongzhi2001/master
Fix validation_split has no output and add validation_data parameter to model.fit
2 parents 86eb48b + fb1a863 commit 85aba7a

File tree

8 files changed

+234
-37
lines changed

8 files changed

+234
-37
lines changed

src/TensorFlowNET.Core/Keras/Engine/ICallback.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,5 @@ public interface ICallback
1515
void on_predict_end();
1616
void on_test_begin();
1717
void on_test_batch_begin(long step);
18-
void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs);
18+
void on_test_batch_end(long end_step, Dictionary<string, float> logs);
1919
}

src/TensorFlowNET.Core/Keras/Engine/IModel.cs

+4-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ ICallback fit(NDArray x, NDArray y,
2222
int verbose = 1,
2323
List<ICallback> callbacks = null,
2424
float validation_split = 0f,
25+
(NDArray val_x, NDArray val_y)? validation_data = null,
2526
bool shuffle = true,
2627
int initial_epoch = 0,
2728
int max_queue_size = 10,
@@ -34,6 +35,7 @@ ICallback fit(IEnumerable<NDArray> x, NDArray y,
3435
int verbose = 1,
3536
List<ICallback> callbacks = null,
3637
float validation_split = 0f,
38+
(IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null,
3739
bool shuffle = true,
3840
int initial_epoch = 0,
3941
int max_queue_size = 10,
@@ -65,7 +67,8 @@ Dictionary<string, float> evaluate(NDArray x, NDArray y,
6567
int max_queue_size = 10,
6668
int workers = 1,
6769
bool use_multiprocessing = false,
68-
bool return_dict = false);
70+
bool return_dict = false,
71+
bool is_val = false);
6972

7073
Tensors predict(Tensors x,
7174
int batch_size = -1,

src/TensorFlowNET.Keras/Callbacks/CallbackList.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ public void on_test_batch_begin(long step)
6969
{
7070
callbacks.ForEach(x => x.on_test_batch_begin(step));
7171
}
72-
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
72+
public void on_test_batch_end(long end_step, Dictionary<string, float> logs)
7373
{
7474
callbacks.ForEach(x => x.on_test_batch_end(end_step, logs));
7575
}

src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs
121121
public void on_predict_end() { }
122122
public void on_test_begin() { }
123123
public void on_test_batch_begin(long step) { }
124-
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) { }
124+
public void on_test_batch_end(long end_step, Dictionary<string, float> logs) { }
125125

126126
float get_monitor_value(Dictionary<string, float> logs)
127127
{

src/TensorFlowNET.Keras/Callbacks/History.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
4848
{
4949
history[log.Key] = new List<float>();
5050
}
51-
history[log.Key].Add((float)log.Value);
51+
history[log.Key].Add(log.Value);
5252
}
5353
}
5454

@@ -78,7 +78,7 @@ public void on_test_batch_begin(long step)
7878

7979
}
8080

81-
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
81+
public void on_test_batch_end(long end_step, Dictionary<string, float> logs)
8282
{
8383
}
8484
}

src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,11 @@ public void on_test_batch_begin(long step)
105105
{
106106
_sw.Restart();
107107
}
108-
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
108+
public void on_test_batch_end(long end_step, Dictionary<string, float> logs)
109109
{
110110
_sw.Stop();
111111
var elapse = _sw.ElapsedMilliseconds;
112-
var results = string.Join(" - ", logs.Select(x => $"{x.Item1}: {(float)x.Item2.numpy():F6}"));
112+
var results = string.Join(" - ", logs.Select(x => $"{x.Key}: {x.Value:F6}"));
113113

114114
Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} - {elapse}ms/step - {results}");
115115
if (!Console.IsOutputRedirected)

src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

+71-13
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,17 @@ public partial class Model
2626
/// <param name="workers"></param>
2727
/// <param name="use_multiprocessing"></param>
2828
/// <param name="return_dict"></param>
29+
/// <param name="is_val"></param>
2930
public Dictionary<string, float> evaluate(NDArray x, NDArray y,
3031
int batch_size = -1,
3132
int verbose = 1,
3233
int steps = -1,
3334
int max_queue_size = 10,
3435
int workers = 1,
3536
bool use_multiprocessing = false,
36-
bool return_dict = false)
37+
bool return_dict = false,
38+
bool is_val = false
39+
)
3740
{
3841
if (x.dims[0] != y.dims[0])
3942
{
@@ -63,31 +66,76 @@ public Dictionary<string, float> evaluate(NDArray x, NDArray y,
6366
});
6467
callbacks.on_test_begin();
6568

66-
IEnumerable<(string, Tensor)> logs = null;
69+
//Dictionary<string, float>? logs = null;
70+
var logs = new Dictionary<string, float>();
6771
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
6872
{
6973
reset_metrics();
70-
callbacks.on_epoch_begin(epoch);
7174
// data_handler.catch_stop_iteration();
7275

7376
foreach (var step in data_handler.steps())
7477
{
7578
callbacks.on_test_batch_begin(step);
7679
logs = test_function(data_handler, iterator);
7780
var end_step = step + data_handler.StepIncrement;
78-
callbacks.on_test_batch_end(end_step, logs);
81+
if (is_val == false)
82+
callbacks.on_test_batch_end(end_step, logs);
7983
}
8084
}
8185

8286
var results = new Dictionary<string, float>();
8387
foreach (var log in logs)
8488
{
85-
results[log.Item1] = (float)log.Item2;
89+
results[log.Key] = log.Value;
8690
}
8791
return results;
8892
}
8993

90-
public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1)
94+
public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, NDArray y, int verbose = 1, bool is_val = false)
95+
{
96+
var data_handler = new DataHandler(new DataHandlerArgs
97+
{
98+
X = new Tensors(x),
99+
Y = y,
100+
Model = this,
101+
StepsPerExecution = _steps_per_execution
102+
});
103+
104+
var callbacks = new CallbackList(new CallbackParams
105+
{
106+
Model = this,
107+
Verbose = verbose,
108+
Steps = data_handler.Inferredsteps
109+
});
110+
callbacks.on_test_begin();
111+
112+
Dictionary<string, float> logs = null;
113+
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
114+
{
115+
reset_metrics();
116+
callbacks.on_epoch_begin(epoch);
117+
// data_handler.catch_stop_iteration();
118+
119+
foreach (var step in data_handler.steps())
120+
{
121+
callbacks.on_test_batch_begin(step);
122+
logs = test_step_multi_inputs_function(data_handler, iterator);
123+
var end_step = step + data_handler.StepIncrement;
124+
if (is_val == false)
125+
callbacks.on_test_batch_end(end_step, logs);
126+
}
127+
}
128+
129+
var results = new Dictionary<string, float>();
130+
foreach (var log in logs)
131+
{
132+
results[log.Key] = log.Value;
133+
}
134+
return results;
135+
}
136+
137+
138+
public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false)
91139
{
92140
var data_handler = new DataHandler(new DataHandlerArgs
93141
{
@@ -104,7 +152,7 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1)
104152
});
105153
callbacks.on_test_begin();
106154

107-
IEnumerable<(string, Tensor)> logs = null;
155+
Dictionary<string, float> logs = null;
108156
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
109157
{
110158
reset_metrics();
@@ -113,36 +161,46 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1)
113161

114162
foreach (var step in data_handler.steps())
115163
{
116-
// callbacks.on_train_batch_begin(step)
164+
callbacks.on_test_batch_begin(step);
117165
logs = test_function(data_handler, iterator);
166+
var end_step = step + data_handler.StepIncrement;
167+
if (is_val == false)
168+
callbacks.on_test_batch_end(end_step, logs);
118169
}
119170
}
120171

121172
var results = new Dictionary<string, float>();
122173
foreach (var log in logs)
123174
{
124-
results[log.Item1] = (float)log.Item2;
175+
results[log.Key] = log.Value;
125176
}
126177
return results;
127178
}
128179

129-
IEnumerable<(string, Tensor)> test_function(DataHandler data_handler, OwnedIterator iterator)
180+
Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
130181
{
131182
var data = iterator.next();
132183
var outputs = test_step(data_handler, data[0], data[1]);
133184
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
134185
return outputs;
135186
}
136-
137-
List<(string, Tensor)> test_step(DataHandler data_handler, Tensor x, Tensor y)
187+
Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)
188+
{
189+
var data = iterator.next();
190+
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
191+
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
192+
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
193+
return outputs;
194+
}
195+
Dictionary<string, float> test_step(DataHandler data_handler, Tensor x, Tensor y)
138196
{
139197
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
140198
var y_pred = Apply(x, training: false);
141199
var loss = compiled_loss.Call(y, y_pred);
142200

143201
compiled_metrics.update_state(y, y_pred);
144202

145-
return metrics.Select(x => (x.Name, x.result())).ToList();
203+
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x=>x.Item1, x=>(float)x.Item2);
146204
}
147205
}
148206
}

0 commit comments

Comments
 (0)