@@ -26,14 +26,17 @@ public partial class Model
26
26
/// <param name="workers"></param>
27
27
/// <param name="use_multiprocessing"></param>
28
28
/// <param name="return_dict"></param>
29
+ /// <param name="is_val"></param>
29
30
public Dictionary < string , float > evaluate ( NDArray x , NDArray y ,
30
31
int batch_size = - 1 ,
31
32
int verbose = 1 ,
32
33
int steps = - 1 ,
33
34
int max_queue_size = 10 ,
34
35
int workers = 1 ,
35
36
bool use_multiprocessing = false ,
36
- bool return_dict = false )
37
+ bool return_dict = false ,
38
+ bool is_val = false
39
+ )
37
40
{
38
41
if ( x . dims [ 0 ] != y . dims [ 0 ] )
39
42
{
@@ -63,31 +66,76 @@ public Dictionary<string, float> evaluate(NDArray x, NDArray y,
63
66
} ) ;
64
67
callbacks . on_test_begin ( ) ;
65
68
66
- IEnumerable < ( string , Tensor ) > logs = null ;
69
+ //Dictionary<string, float>? logs = null;
70
+ var logs = new Dictionary < string , float > ( ) ;
67
71
foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
68
72
{
69
73
reset_metrics ( ) ;
70
- callbacks . on_epoch_begin ( epoch ) ;
71
74
// data_handler.catch_stop_iteration();
72
75
73
76
foreach ( var step in data_handler . steps ( ) )
74
77
{
75
78
callbacks . on_test_batch_begin ( step ) ;
76
79
logs = test_function ( data_handler , iterator ) ;
77
80
var end_step = step + data_handler . StepIncrement ;
78
- callbacks . on_test_batch_end ( end_step , logs ) ;
81
+ if ( is_val == false )
82
+ callbacks . on_test_batch_end ( end_step , logs ) ;
79
83
}
80
84
}
81
85
82
86
var results = new Dictionary < string , float > ( ) ;
83
87
foreach ( var log in logs )
84
88
{
85
- results [ log . Item1 ] = ( float ) log . Item2 ;
89
+ results [ log . Key ] = log . Value ;
86
90
}
87
91
return results ;
88
92
}
89
93
90
- public Dictionary < string , float > evaluate ( IDatasetV2 x , int verbose = 1 )
94
+ public Dictionary < string , float > evaluate ( IEnumerable < Tensor > x , NDArray y , int verbose = 1 , bool is_val = false )
95
+ {
96
+ var data_handler = new DataHandler ( new DataHandlerArgs
97
+ {
98
+ X = new Tensors ( x ) ,
99
+ Y = y ,
100
+ Model = this ,
101
+ StepsPerExecution = _steps_per_execution
102
+ } ) ;
103
+
104
+ var callbacks = new CallbackList ( new CallbackParams
105
+ {
106
+ Model = this ,
107
+ Verbose = verbose ,
108
+ Steps = data_handler . Inferredsteps
109
+ } ) ;
110
+ callbacks . on_test_begin ( ) ;
111
+
112
+ Dictionary < string , float > logs = null ;
113
+ foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
114
+ {
115
+ reset_metrics ( ) ;
116
+ callbacks . on_epoch_begin ( epoch ) ;
117
+ // data_handler.catch_stop_iteration();
118
+
119
+ foreach ( var step in data_handler . steps ( ) )
120
+ {
121
+ callbacks . on_test_batch_begin ( step ) ;
122
+ logs = test_step_multi_inputs_function ( data_handler , iterator ) ;
123
+ var end_step = step + data_handler . StepIncrement ;
124
+ if ( is_val == false )
125
+ callbacks . on_test_batch_end ( end_step , logs ) ;
126
+ }
127
+ }
128
+
129
+ var results = new Dictionary < string , float > ( ) ;
130
+ foreach ( var log in logs )
131
+ {
132
+ results [ log . Key ] = log . Value ;
133
+ }
134
+ return results ;
135
+ }
136
+
137
+
138
+ public Dictionary < string , float > evaluate ( IDatasetV2 x , int verbose = 1 , bool is_val = false )
91
139
{
92
140
var data_handler = new DataHandler ( new DataHandlerArgs
93
141
{
@@ -104,7 +152,7 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1)
104
152
} ) ;
105
153
callbacks . on_test_begin ( ) ;
106
154
107
- IEnumerable < ( string , Tensor ) > logs = null ;
155
+ Dictionary < string , float > logs = null ;
108
156
foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
109
157
{
110
158
reset_metrics ( ) ;
@@ -113,36 +161,46 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1)
113
161
114
162
foreach ( var step in data_handler . steps ( ) )
115
163
{
116
- // callbacks.on_train_batch_begin (step)
164
+ callbacks . on_test_batch_begin ( step ) ;
117
165
logs = test_function ( data_handler , iterator ) ;
166
+ var end_step = step + data_handler . StepIncrement ;
167
+ if ( is_val == false )
168
+ callbacks . on_test_batch_end ( end_step , logs ) ;
118
169
}
119
170
}
120
171
121
172
var results = new Dictionary < string , float > ( ) ;
122
173
foreach ( var log in logs )
123
174
{
124
- results [ log . Item1 ] = ( float ) log . Item2 ;
175
+ results [ log . Key ] = log . Value ;
125
176
}
126
177
return results ;
127
178
}
128
179
129
- IEnumerable < ( string , Tensor ) > test_function ( DataHandler data_handler , OwnedIterator iterator )
180
+ Dictionary < string , float > test_function ( DataHandler data_handler , OwnedIterator iterator )
130
181
{
131
182
var data = iterator . next ( ) ;
132
183
var outputs = test_step ( data_handler , data [ 0 ] , data [ 1 ] ) ;
133
184
tf_with ( ops . control_dependencies ( new object [ 0 ] ) , ctl => _test_counter . assign_add ( 1 ) ) ;
134
185
return outputs ;
135
186
}
136
-
137
- List < ( string , Tensor ) > test_step ( DataHandler data_handler , Tensor x , Tensor y )
187
+ Dictionary < string , float > test_step_multi_inputs_function ( DataHandler data_handler , OwnedIterator iterator )
188
+ {
189
+ var data = iterator . next ( ) ;
190
+ var x_size = data_handler . DataAdapter . GetDataset ( ) . FirstInputTensorCount ;
191
+ var outputs = train_step ( data_handler , new Tensors ( data . Take ( x_size ) ) , new Tensors ( data . Skip ( x_size ) ) ) ;
192
+ tf_with ( ops . control_dependencies ( new object [ 0 ] ) , ctl => _train_counter . assign_add ( 1 ) ) ;
193
+ return outputs ;
194
+ }
195
+ Dictionary < string , float > test_step ( DataHandler data_handler , Tensor x , Tensor y )
138
196
{
139
197
( x , y ) = data_handler . DataAdapter . Expand1d ( x , y ) ;
140
198
var y_pred = Apply ( x , training : false ) ;
141
199
var loss = compiled_loss . Call ( y , y_pred ) ;
142
200
143
201
compiled_metrics . update_state ( y , y_pred ) ;
144
202
145
- return metrics . Select ( x => ( x . Name , x . result ( ) ) ) . ToList ( ) ;
203
+ return metrics . Select ( x => ( x . Name , x . result ( ) ) ) . ToDictionary ( x => x . Item1 , x => ( float ) x . Item2 ) ;
146
204
}
147
205
}
148
206
}
0 commit comments