Skip to content

Commit 476f8cf

Browse files
committed
add predict for logistic regression.
1 parent 5a2433c commit 476f8cf

File tree

12 files changed

+196
-18
lines changed

12 files changed

+196
-18
lines changed

docs/source/Graph.md

+58
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,61 @@ A typical graph is looks like below:
2121

2222
![image](../assets/graph_vis_animation.gif)
2323

24+
25+
26+
### Save Model
27+
28+
Saving the model means saving all the values of the parameters and the graph.
29+
30+
```python
31+
saver = tf.train.Saver()
32+
saver.save(sess,'./tensorflowModel.ckpt')
33+
```
34+
35+
After saving the model there will be four files:
36+
37+
* tensorflowModel.ckpt.meta:
38+
* tensorflowModel.ckpt.data-00000-of-00001:
39+
* tensorflowModel.ckpt.index
40+
* checkpoint
41+
42+
We also created a protocol buffer file .pbtxt. It is human readable if you want to convert it to binary: `as_text: false`.
43+
44+
* tensorflowModel.pbtxt:
45+
46+
This holds a network of nodes, each representing one operation, connected to each other as inputs and outputs.
47+
48+
49+
50+
### Freezing the Graph
51+
52+
##### *Why we need it?*
53+
54+
When we need to keep all the values of the variables and the Graph structure in a single file we have to freeze the graph.
55+
56+
```csharp
57+
from tensorflow.python.tools import freeze_graph
58+
59+
freeze_graph.freeze_graph(input_graph = 'logistic_regression/tensorflowModel.pbtxt',
60+
input_saver = "",
61+
input_binary = False,
62+
input_checkpoint = 'logistic_regression/tensorflowModel.ckpt',
63+
output_node_names = "Softmax",
64+
restore_op_name = "save/restore_all",
65+
filename_tensor_name = "save/Const:0",
66+
output_graph = 'frozentensorflowModel.pb',
67+
clear_devices = True,
68+
initializer_nodes = "")
69+
70+
```
71+
72+
### Optimizing for Inference
73+
74+
To Reduce the amount of computation needed when the network is used only for inferences we can remove some parts of a graph that are only needed for training.
75+
76+
77+
78+
### Restoring the Model
79+
80+
81+

docs/source/LogisticRegression.md

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Chapter. Logistic Regression
2+
3+
### What is logistic regression?
4+
5+
6+
7+
The full example is [here](https://github.com/SciSharp/TensorFlow.NET/blob/master/test/TensorFlowNET.Examples/LogisticRegression.cs).

docs/source/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@ Welcome to TensorFlow.NET's documentation!
2626
Train
2727
EagerMode
2828
LinearRegression
29+
LogisticRegression
2930
ImageRecognition

src/TensorFlowNET.Core/Framework/meta_graph.py.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ private static MetaGraphDef create_meta_graph_def(MetaInfoDef meta_info_def = nu
184184

185185
// Adds graph_def or the default.
186186
if (graph_def == null)
187-
meta_graph_def.GraphDef = graph._as_graph_def(add_shapes: true);
187+
meta_graph_def.GraphDef = graph.as_graph_def(add_shapes: true);
188188
else
189189
meta_graph_def.GraphDef = graph_def;
190190

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class FreezeGraph
8+
{
9+
public static void freeze_graph(string input_graph,
10+
string input_saver,
11+
bool input_binary,
12+
string input_checkpoint,
13+
string output_node_names,
14+
string restore_op_name,
15+
string filename_tensor_name,
16+
string output_graph,
17+
bool clear_devices,
18+
string initializer_nodes)
19+
{
20+
21+
}
22+
}
23+
}

src/TensorFlowNET.Core/Graphs/Graph.Export.cs

+4-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public Buffer ToGraphDef(Status s)
1818
return buffer;
1919
}
2020

21-
public GraphDef _as_graph_def(bool add_shapes = false)
21+
private GraphDef _as_graph_def(bool add_shapes = false)
2222
{
2323
var buffer = ToGraphDef(Status);
2424
Status.Check();
@@ -30,5 +30,8 @@ public GraphDef _as_graph_def(bool add_shapes = false)
3030

3131
return def;
3232
}
33+
34+
public GraphDef as_graph_def(bool add_shapes = false)
35+
=> _as_graph_def(add_shapes);
3336
}
3437
}

src/TensorFlowNET.Core/Graphs/graph_io.py.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public class graph_io
1010
{
1111
public static string write_graph(Graph graph, string logdir, string name, bool as_text = true)
1212
{
13-
var graph_def = graph._as_graph_def();
13+
var graph_def = graph.as_graph_def();
1414
string path = Path.Combine(logdir, name);
1515
if (as_text)
1616
File.WriteAllText(path, graph_def.ToString());

src/TensorFlowNET.Core/Sessions/_FetchHandler.cs

+17-10
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,24 @@ public NDArray build_results(BaseSession session, NDArray[] tensor_values)
5858
{
5959
var value = tensor_values[j];
6060
j += 1;
61-
switch (value.dtype.Name)
61+
if (value.ndim == 2)
6262
{
63-
case "Int32":
64-
full_values.Add(value.Data<int>(0));
65-
break;
66-
case "Single":
67-
full_values.Add(value.Data<float>(0));
68-
break;
69-
case "Double":
70-
full_values.Add(value.Data<double>(0));
71-
break;
63+
full_values.Add(value[0]);
64+
}
65+
else
66+
{
67+
switch (value.dtype.Name)
68+
{
69+
case "Int32":
70+
full_values.Add(value.Data<int>(0));
71+
break;
72+
case "Single":
73+
full_values.Add(value.Data<float>(0));
74+
break;
75+
case "Double":
76+
full_values.Add(value.Data<double>(0));
77+
break;
78+
}
7279
}
7380
}
7481
i += 1;

src/TensorFlowNET.Core/Train/Saving/Saver.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ public MetaGraphDef export_meta_graph(string filename= "",
251251
{
252252
return export_meta_graph(
253253
filename: filename,
254-
graph_def: ops.get_default_graph()._as_graph_def(add_shapes: true),
254+
graph_def: ops.get_default_graph().as_graph_def(add_shapes: true),
255255
saver_def: _saver_def,
256256
collection_list: collection_list,
257257
as_text: as_text,

src/TensorFlowNET.Core/Train/tf.optimizers.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ public static class train
1616

1717
public static Saver Saver() => new Saver();
1818

19-
public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text);
19+
public static string write_graph(Graph graph, string logdir, string name, bool as_text = true)
20+
=> graph_io.write_graph(graph, logdir, name, as_text);
2021

2122
public static Saver import_meta_graph(string meta_graph_or_file,
2223
bool clear_devices = false,

test/TensorFlowNET.Examples/LogisticRegression.cs

+49-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using NumSharp.Core;
33
using System;
44
using System.Collections.Generic;
5+
using System.IO;
56
using System.Linq;
67
using System.Text;
78
using Tensorflow;
@@ -17,7 +18,7 @@ namespace TensorFlowNET.Examples
1718
public class LogisticRegression : Python, IExample
1819
{
1920
private float learning_rate = 0.01f;
20-
private int training_epochs = 5;
21+
private int training_epochs = 10;
2122
private int batch_size = 100;
2223
private int display_step = 1;
2324

@@ -78,19 +79,66 @@ public void Run()
7879
}
7980

8081
print("Optimization Finished!");
82+
// SaveModel(sess);
8183

8284
// Test model
8385
var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1));
8486
// Calculate accuracy
8587
var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
8688
float acc = accuracy.eval(new FeedItem(x, mnist.test.images), new FeedItem(y, mnist.test.labels));
8789
print($"Accuracy: {acc.ToString("F4")}");
90+
91+
Predict();
8892
});
8993
}
9094

9195
public void PrepareData()
9296
{
9397
mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
9498
}
99+
100+
public void SaveModel(Session sess)
101+
{
102+
var saver = tf.train.Saver();
103+
var save_path = saver.save(sess, "logistic_regression/model.ckpt");
104+
tf.train.write_graph(sess.graph, "logistic_regression", "model.pbtxt", as_text: true);
105+
106+
FreezeGraph.freeze_graph(input_graph: "logistic_regression/model.pbtxt",
107+
input_saver: "",
108+
input_binary: false,
109+
input_checkpoint: "logistic_regression/model.ckpt",
110+
output_node_names: "Softmax",
111+
restore_op_name: "save/restore_all",
112+
filename_tensor_name: "save/Const:0",
113+
output_graph: "logistic_regression/model.pb",
114+
clear_devices: true,
115+
initializer_nodes: "");
116+
}
117+
118+
public void Predict()
119+
{
120+
var graph = new Graph().as_default();
121+
graph.Import(Path.Join("logistic_regression", "model.pb"));
122+
123+
with(tf.Session(graph), sess =>
124+
{
125+
// restoring the model
126+
// var saver = tf.train.import_meta_graph("logistic_regression/tensorflowModel.ckpt.meta");
127+
// saver.restore(sess, tf.train.latest_checkpoint('logistic_regression'));
128+
var pred = graph.OperationByName("Softmax");
129+
var output = pred.outputs[0];
130+
var x = graph.OperationByName("Placeholder");
131+
var input = x.outputs[0];
132+
133+
// predict
134+
var (batch_xs, batch_ys) = mnist.train.next_batch(10);
135+
var results = sess.run(output, new FeedItem(input, batch_xs[np.arange(1)]));
136+
137+
if (results.argmax() == (batch_ys[0] as NDArray).argmax())
138+
print("predicted OK!");
139+
else
140+
throw new ValueError("predict error, maybe 90% accuracy");
141+
});
142+
}
95143
}
96144
}

test/TensorFlowNET.Examples/python/logistic_regression.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
# Parameters
1818
learning_rate = 0.01
19-
training_epochs = 25
19+
training_epochs = 10
2020
batch_size = 100
2121
display_step = 1
2222

@@ -67,4 +67,34 @@
6767
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
6868
# Calculate accuracy
6969
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
70-
print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
70+
print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
71+
72+
# predict
73+
# results = sess.run(pred, feed_dict={x: batch_xs[:1]})
74+
75+
# save model
76+
saver = tf.train.Saver()
77+
save_path = saver.save(sess, "logistic_regression/model.ckpt")
78+
tf.train.write_graph(sess.graph.as_graph_def(),'logistic_regression','model.pbtxt', as_text=True)
79+
80+
freeze_graph.freeze_graph(input_graph = 'logistic_regression/model.pbtxt',
81+
input_saver = "",
82+
input_binary = False,
83+
input_checkpoint = 'logistic_regression/model.ckpt',
84+
output_node_names = "Softmax",
85+
restore_op_name = "save/restore_all",
86+
filename_tensor_name = "save/Const:0",
87+
output_graph = 'logistic_regression/model.pb',
88+
clear_devices = True,
89+
initializer_nodes = "")
90+
91+
# restoring the model
92+
saver = tf.train.import_meta_graph('logistic_regression/tensorflowModel.ckpt.meta')
93+
saver.restore(sess,tf.train.latest_checkpoint('logistic_regression'))
94+
95+
# predict
96+
# pred = graph._nodes_by_name["Softmax"]
97+
# output = pred.outputs[0]
98+
# x = graph._nodes_by_name["Placeholder"]
99+
# input = x.outputs[0]
100+
# results = sess.run(output, feed_dict={input: batch_xs[:1]})

0 commit comments

Comments
 (0)