Skip to content

Commit 9b11d45

Browse files
committed
Fix NeuralNetXorKeras accuracy. #952
1 parent 3370723 commit 9b11d45

File tree

7 files changed

+92
-21
lines changed

7 files changed

+92
-21
lines changed

src/TensorFlowNET.Keras/Engine/Functional.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ protected void _init_graph_network(Tensors inputs, Tensors outputs)
7171
NodesByDepth = nodes_by_depth;
7272
if (_layers.Count == 0)
7373
_layers = layers;
74-
74+
_self_tracked_trackables = layers;
7575
// Build self.input_names and self.output_names.
7676
_set_output_names();
7777

src/TensorFlowNET.Keras/Engine/Layer.AddWeights.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ protected virtual IVariableV1 add_weight(string name,
5353

5454
//backend.track_variable(variable);
5555
if (trainable == true)
56-
trainable_weights.Add(variable);
56+
_trainable_weights.Add(variable);
5757
else
58-
non_trainable_weights.Add(variable);
58+
_non_trainable_weights.Add(variable);
5959

6060
return variable;
6161
}

src/TensorFlowNET.Keras/Engine/Layer.cs

+10-10
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,12 @@ public abstract partial class Layer : AutoTrackable, ILayer
6161
protected InputSpec inputSpec;
6262
bool dynamic = true;
6363
public bool SupportsMasking { get; set; }
64-
protected List<IVariableV1> trainable_weights;
64+
protected List<IVariableV1> _trainable_weights;
6565

66-
public virtual List<IVariableV1> trainable_variables => trainable_weights;
66+
public virtual List<IVariableV1> trainable_variables => _trainable_weights;
6767

68-
protected List<IVariableV1> non_trainable_weights;
69-
public List<IVariableV1> non_trainable_variables => non_trainable_weights;
68+
protected List<IVariableV1> _non_trainable_weights;
69+
public List<IVariableV1> non_trainable_variables => _non_trainable_weights;
7070

7171
protected int id;
7272
public int Id => id;
@@ -104,8 +104,8 @@ public Layer(LayerArgs args)
104104

105105
id = ops.uid_layer();
106106
_init_set_name(args.Name);
107-
trainable_weights = new List<IVariableV1>();
108-
non_trainable_weights = new List<IVariableV1>();
107+
_trainable_weights = new List<IVariableV1>();
108+
_non_trainable_weights = new List<IVariableV1>();
109109
computePreviousMask = false;
110110
updates = new List<Operation>();
111111
_self_tracked_trackables = new List<ILayer>();
@@ -254,15 +254,15 @@ List<IVariableV1> ILayer.trainable_weights
254254
{
255255
get
256256
{
257-
return trainable_weights;
257+
return _trainable_weights;
258258
}
259259
}
260260

261261
List<IVariableV1> ILayer.non_trainable_weights
262262
{
263263
get
264264
{
265-
return non_trainable_weights;
265+
return _non_trainable_weights;
266266
}
267267
}
268268

@@ -271,8 +271,8 @@ public List<IVariableV1> weights
271271
get
272272
{
273273
var weights = new List<IVariableV1>();
274-
weights.AddRange(trainable_weights);
275-
weights.AddRange(non_trainable_weights);
274+
weights.AddRange(_trainable_weights);
275+
weights.AddRange(_non_trainable_weights);
276276
return weights;
277277
}
278278
set

src/TensorFlowNET.Keras/Engine/Model.Fit.cs

+40-7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Linq;
55
using Tensorflow.Keras.ArgsDefinition;
66
using Tensorflow.Keras.Engine.DataAdapters;
7+
using System.Diagnostics;
78

89
namespace Tensorflow.Keras.Engine
910
{
@@ -87,25 +88,57 @@ void FitInternal(int epochs, int verbose)
8788
{
8889
stop_training = false;
8990
_train_counter.assign(0);
91+
Stopwatch sw = new Stopwatch();
9092
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
9193
{
9294
reset_metrics();
93-
// callbacks.on_epoch_begin(epoch)
95+
on_epoch_begin(epoch, epochs);
9496
// data_handler.catch_stop_iteration();
9597
foreach (var step in data_handler.steps())
9698
{
97-
// callbacks.on_train_batch_begin(step)
99+
sw.Start();
98100
var results = train_step_function(iterator);
99-
if (verbose == 1)
101+
sw.Stop();
102+
on_train_batch_begin(verbose, step, sw.ElapsedMilliseconds, results);
103+
104+
// recycle memory more frequency
105+
if (sw.ElapsedMilliseconds > 100)
100106
{
101-
var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}"));
102-
Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}");
107+
GC.Collect();
103108
}
104-
105-
GC.Collect();
109+
sw.Reset();
106110
}
111+
Console.WriteLine();
112+
113+
GC.Collect();
107114
GC.WaitForPendingFinalizers();
108115
}
109116
}
117+
118+
void on_epoch_begin(int epoch, int epochs)
119+
{
120+
Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}");
121+
}
122+
123+
void on_train_batch_begin(int verbose, long step, long elapse, IEnumerable<(string, Tensor)> results)
124+
{
125+
if (verbose == 1)
126+
{
127+
var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}"));
128+
129+
var progress = "";
130+
for (int i = 0; i < step + 1; i++)
131+
for (int j = 0; j < 30 / data_handler.Inferredsteps; j++)
132+
progress += "=";
133+
progress += ">";
134+
135+
var remaining = "";
136+
for (int i = 1; i < 30 - progress.Length; i++)
137+
remaining += ".";
138+
139+
Binding.tf_output_redirect.Write($"{step + 1:D4}/{data_handler.Inferredsteps:D4} [{progress}{remaining}] - {elapse}ms/step {result_pairs}");
140+
Console.CursorLeft = 0;
141+
}
142+
}
110143
}
111144
}

src/TensorFlowNET.Keras/Engine/Model.cs

+15
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,26 @@ public override List<IVariableV1> trainable_variables
7575
get
7676
{
7777
var variables = new List<IVariableV1>();
78+
79+
if (!Trainable)
80+
{
81+
return variables;
82+
}
83+
84+
foreach (var trackable_obj in _self_tracked_trackables)
85+
{
86+
if (trackable_obj.Trainable)
87+
variables.AddRange(trackable_obj.trainable_variables);
88+
}
89+
7890
foreach (var layer in _layers)
7991
{
8092
if (layer.Trainable)
8193
variables.AddRange(layer.trainable_variables);
8294
}
95+
96+
// variables.AddRange(_trainable_weights);
97+
8398
return variables;
8499
}
85100
}

src/python/.vscode/launch.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"name": "Python: Current File",
99
"type": "python",
1010
"request": "launch",
11-
"program": "${file}",
11+
"program": "${workspaceFolder}/xor_keras.py",
1212
"console": "integratedTerminal",
1313
"justMyCode": false
1414
}

src/python/xor_keras.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import os
2+
import numpy as np
3+
import tensorflow as tf
4+
5+
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
6+
print(tf.__version__)
7+
# tf.compat.v1.enable_eager_execution()
8+
# tf.debugging.set_log_device_placement(True);
9+
tf.config.run_functions_eagerly(True)
10+
11+
x = np.array([[ 0, 0 ], [ 0, 1 ], [ 1, 0 ], [ 1, 1 ]])
12+
y = np.array([[ 0 ], [ 1 ], [ 1 ], [ 0 ] ])
13+
14+
model = tf.keras.Sequential()
15+
model.add(tf.keras.Input(2))
16+
model.add(tf.keras.layers.Dense(32, "relu"))
17+
model.add(tf.keras.layers.Dense(1, "sigmoid"))
18+
model.compile(optimizer = tf.keras.optimizers.Adam(),
19+
loss = tf.keras.losses.MeanSquaredError(),
20+
metrics = ["accuracy"])
21+
model.fit(x, y, 1, 100)
22+
result = model.evaluate(x, y)
23+
print(model.predict(x, 4))

0 commit comments

Comments
 (0)