2
2
using NumSharp . Core ;
3
3
using System ;
4
4
using System . Collections . Generic ;
5
+ using System . IO ;
5
6
using System . Linq ;
6
7
using System . Text ;
7
8
using Tensorflow ;
@@ -17,7 +18,7 @@ namespace TensorFlowNET.Examples
17
18
public class LogisticRegression : Python , IExample
18
19
{
19
20
private float learning_rate = 0.01f ;
20
- private int training_epochs = 5 ;
21
+ private int training_epochs = 10 ;
21
22
private int batch_size = 100 ;
22
23
private int display_step = 1 ;
23
24
@@ -78,19 +79,66 @@ public void Run()
78
79
}
79
80
80
81
print ( "Optimization Finished!" ) ;
82
+ // SaveModel(sess);
81
83
82
84
// Test model
83
85
var correct_prediction = tf . equal ( tf . argmax ( pred , 1 ) , tf . argmax ( y , 1 ) ) ;
84
86
// Calculate accuracy
85
87
var accuracy = tf . reduce_mean ( tf . cast ( correct_prediction , tf . float32 ) ) ;
86
88
float acc = accuracy . eval ( new FeedItem ( x , mnist . test . images ) , new FeedItem ( y , mnist . test . labels ) ) ;
87
89
print ( $ "Accuracy: { acc . ToString ( "F4" ) } ") ;
90
+
91
+ Predict ( ) ;
88
92
} ) ;
89
93
}
90
94
91
95
public void PrepareData ( )
92
96
{
93
97
mnist = MnistDataSet . read_data_sets ( "logistic_regression" , one_hot : true ) ;
94
98
}
99
+
100
+ public void SaveModel ( Session sess )
101
+ {
102
+ var saver = tf . train . Saver ( ) ;
103
+ var save_path = saver . save ( sess , "logistic_regression/model.ckpt" ) ;
104
+ tf . train . write_graph ( sess . graph , "logistic_regression" , "model.pbtxt" , as_text : true ) ;
105
+
106
+ FreezeGraph . freeze_graph ( input_graph : "logistic_regression/model.pbtxt" ,
107
+ input_saver : "" ,
108
+ input_binary : false ,
109
+ input_checkpoint : "logistic_regression/model.ckpt" ,
110
+ output_node_names : "Softmax" ,
111
+ restore_op_name : "save/restore_all" ,
112
+ filename_tensor_name : "save/Const:0" ,
113
+ output_graph : "logistic_regression/model.pb" ,
114
+ clear_devices : true ,
115
+ initializer_nodes : "" ) ;
116
+ }
117
+
118
+ public void Predict ( )
119
+ {
120
+ var graph = new Graph ( ) . as_default ( ) ;
121
+ graph . Import ( Path . Join ( "logistic_regression" , "model.pb" ) ) ;
122
+
123
+ with ( tf . Session ( graph ) , sess =>
124
+ {
125
+ // restoring the model
126
+ // var saver = tf.train.import_meta_graph("logistic_regression/tensorflowModel.ckpt.meta");
127
+ // saver.restore(sess, tf.train.latest_checkpoint('logistic_regression'));
128
+ var pred = graph . OperationByName ( "Softmax" ) ;
129
+ var output = pred . outputs [ 0 ] ;
130
+ var x = graph . OperationByName ( "Placeholder" ) ;
131
+ var input = x . outputs [ 0 ] ;
132
+
133
+ // predict
134
+ var ( batch_xs , batch_ys ) = mnist . train . next_batch ( 10 ) ;
135
+ var results = sess . run ( output , new FeedItem ( input , batch_xs [ np . arange ( 1 ) ] ) ) ;
136
+
137
+ if ( results . argmax ( ) == ( batch_ys [ 0 ] as NDArray ) . argmax ( ) )
138
+ print ( "predicted OK!" ) ;
139
+ else
140
+ throw new ValueError ( "predict error, maybe 90% accuracy" ) ;
141
+ } ) ;
142
+ }
95
143
}
96
144
}
0 commit comments