|
4 | 4 | using System.Linq; |
5 | 5 | using Tensorflow.Keras.ArgsDefinition; |
6 | 6 | using Tensorflow.Keras.Engine.DataAdapters; |
| 7 | +using System.Diagnostics; |
7 | 8 |
|
8 | 9 | namespace Tensorflow.Keras.Engine |
9 | 10 | { |
@@ -87,25 +88,57 @@ void FitInternal(int epochs, int verbose) |
87 | 88 | { |
88 | 89 | stop_training = false; |
89 | 90 | _train_counter.assign(0); |
| 91 | + Stopwatch sw = new Stopwatch(); |
90 | 92 | foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) |
91 | 93 | { |
92 | 94 | reset_metrics(); |
93 | | - // callbacks.on_epoch_begin(epoch) |
| 95 | + on_epoch_begin(epoch, epochs); |
94 | 96 | // data_handler.catch_stop_iteration(); |
95 | 97 | foreach (var step in data_handler.steps()) |
96 | 98 | { |
97 | | - // callbacks.on_train_batch_begin(step) |
| 99 | + sw.Start(); |
98 | 100 | var results = train_step_function(iterator); |
99 | | - if (verbose == 1) |
| 101 | + sw.Stop(); |
| 102 | + on_train_batch_begin(verbose, step, sw.ElapsedMilliseconds, results); |
| 103 | + |
| 104 | + // recycle memory more frequency |
| 105 | + if (sw.ElapsedMilliseconds > 100) |
100 | 106 | { |
101 | | - var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); |
102 | | - Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}"); |
| 107 | + GC.Collect(); |
103 | 108 | } |
104 | | - |
105 | | - GC.Collect(); |
| 109 | + sw.Reset(); |
106 | 110 | } |
| 111 | + Console.WriteLine(); |
| 112 | + |
| 113 | + GC.Collect(); |
107 | 114 | GC.WaitForPendingFinalizers(); |
108 | 115 | } |
109 | 116 | } |
| 117 | + |
| 118 | + void on_epoch_begin(int epoch, int epochs) |
| 119 | + { |
| 120 | + Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}"); |
| 121 | + } |
| 122 | + |
| 123 | + void on_train_batch_begin(int verbose, long step, long elapse, IEnumerable<(string, Tensor)> results) |
| 124 | + { |
| 125 | + if (verbose == 1) |
| 126 | + { |
| 127 | + var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); |
| 128 | + |
| 129 | + var progress = ""; |
| 130 | + for (int i = 0; i < step + 1; i++) |
| 131 | + for (int j = 0; j < 30 / data_handler.Inferredsteps; j++) |
| 132 | + progress += "="; |
| 133 | + progress += ">"; |
| 134 | + |
| 135 | + var remaining = ""; |
| 136 | + for (int i = 1; i < 30 - progress.Length; i++) |
| 137 | + remaining += "."; |
| 138 | + |
| 139 | + Binding.tf_output_redirect.Write($"{step + 1:D4}/{data_handler.Inferredsteps:D4} [{progress}{remaining}] - {elapse}ms/step {result_pairs}"); |
| 140 | + Console.CursorLeft = 0; |
| 141 | + } |
| 142 | + } |
110 | 143 | } |
111 | 144 | } |
0 commit comments