|
4 | 4 | using System.Linq;
|
5 | 5 | using Tensorflow.Keras.ArgsDefinition;
|
6 | 6 | using Tensorflow.Keras.Engine.DataAdapters;
|
| 7 | +using System.Diagnostics; |
7 | 8 |
|
8 | 9 | namespace Tensorflow.Keras.Engine
|
9 | 10 | {
|
@@ -87,25 +88,57 @@ void FitInternal(int epochs, int verbose)
|
87 | 88 | {
|
88 | 89 | stop_training = false;
|
89 | 90 | _train_counter.assign(0);
|
| 91 | + Stopwatch sw = new Stopwatch(); |
90 | 92 | foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
|
91 | 93 | {
|
92 | 94 | reset_metrics();
|
93 |
| - // callbacks.on_epoch_begin(epoch) |
| 95 | + on_epoch_begin(epoch, epochs); |
94 | 96 | // data_handler.catch_stop_iteration();
|
95 | 97 | foreach (var step in data_handler.steps())
|
96 | 98 | {
|
97 |
| - // callbacks.on_train_batch_begin(step) |
| 99 | + sw.Start(); |
98 | 100 | var results = train_step_function(iterator);
|
99 |
| - if (verbose == 1) |
| 101 | + sw.Stop(); |
| 102 | + on_train_batch_begin(verbose, step, sw.ElapsedMilliseconds, results); |
| 103 | + |
| 104 | + // recycle memory more frequency |
| 105 | + if (sw.ElapsedMilliseconds > 100) |
100 | 106 | {
|
101 |
| - var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); |
102 |
| - Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}"); |
| 107 | + GC.Collect(); |
103 | 108 | }
|
104 |
| - |
105 |
| - GC.Collect(); |
| 109 | + sw.Reset(); |
106 | 110 | }
|
| 111 | + Console.WriteLine(); |
| 112 | + |
| 113 | + GC.Collect(); |
107 | 114 | GC.WaitForPendingFinalizers();
|
108 | 115 | }
|
109 | 116 | }
|
| 117 | + |
| 118 | + void on_epoch_begin(int epoch, int epochs) |
| 119 | + { |
| 120 | + Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}"); |
| 121 | + } |
| 122 | + |
| 123 | + void on_train_batch_begin(int verbose, long step, long elapse, IEnumerable<(string, Tensor)> results) |
| 124 | + { |
| 125 | + if (verbose == 1) |
| 126 | + { |
| 127 | + var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); |
| 128 | + |
| 129 | + var progress = ""; |
| 130 | + for (int i = 0; i < step + 1; i++) |
| 131 | + for (int j = 0; j < 30 / data_handler.Inferredsteps; j++) |
| 132 | + progress += "="; |
| 133 | + progress += ">"; |
| 134 | + |
| 135 | + var remaining = ""; |
| 136 | + for (int i = 1; i < 30 - progress.Length; i++) |
| 137 | + remaining += "."; |
| 138 | + |
| 139 | + Binding.tf_output_redirect.Write($"{step + 1:D4}/{data_handler.Inferredsteps:D4} [{progress}{remaining}] - {elapse}ms/step {result_pairs}"); |
| 140 | + Console.CursorLeft = 0; |
| 141 | + } |
| 142 | + } |
110 | 143 | }
|
111 | 144 | }
|
0 commit comments