diff --git a/src/TensorFlowNET.Core/Keras/Engine/ICallback.cs b/src/TensorFlowNET.Core/Keras/Engine/ICallback.cs
index 296c32acb..530a93687 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/ICallback.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/ICallback.cs
@@ -4,6 +4,7 @@ public interface ICallback
 {
     Dictionary> history { get; set; }
     void on_train_begin();
+    void on_train_end();
     void on_epoch_begin(int epoch);
     void on_train_batch_begin(long step);
     void on_train_batch_end(long end_step, Dictionary logs);
diff --git a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
index a462a68eb..cb36f6334 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
@@ -17,6 +17,7 @@ ICallback fit(NDArray x, NDArray y,
     int batch_size = -1,
     int epochs = 1,
     int verbose = 1,
+    List callbacks = null,
     float validation_split = 0f,
     bool shuffle = true,
     int initial_epoch = 0,
@@ -28,6 +29,7 @@ ICallback fit(IEnumerable x, NDArray y,
     int batch_size = -1,
     int epochs = 1,
     int verbose = 1,
+    List callbacks = null,
     float validation_split = 0f,
     bool shuffle = true,
     int initial_epoch = 0,
@@ -73,4 +75,6 @@ Tensors predict(Tensors x,
     void summary(int line_length = -1, float[] positions = null);
 
     IKerasConfig get_config();
+
+    void set_stopTraining_true();
 }
diff --git a/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs b/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs
index ddb1fa7f5..a28477982 100644
--- a/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs
+++ b/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs
@@ -7,7 +7,8 @@ namespace Tensorflow.Keras.Callbacks;
 
 public class CallbackList
 {
-    List callbacks = new List();
+    // Made public so that user-defined callbacks can be added to the callback list.
+    public List callbacks = new List();
     public History History => callbacks[0] as History;
 
     public CallbackList(CallbackParams parameters)
@@ -66,7 +67,7 @@ public void on_predict_end()
 
     public void on_test_batch_begin(long step)
     {
-
        callbacks.ForEach(x => x.on_train_batch_begin(step));
+        callbacks.ForEach(x => x.on_test_batch_begin(step));
     }
     public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
     {
diff --git a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs
new file mode 100644
index 000000000..1e0418dc5
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs
@@ -0,0 +1,155 @@
+using Tensorflow.Keras.Engine;
+namespace Tensorflow.Keras.Callbacks;
+
+
+/// 
+/// Stop training when a monitored metric has stopped improving.
+/// 
+/// 
+/// 
+
+public class EarlyStopping: ICallback
+{
+    int _paitence;
+    int _min_delta;
+    int _verbose;
+    int _stopped_epoch;
+    int _wait;
+    int _best_epoch;
+    int _start_from_epoch;
+    float _best;
+    float _baseline;
+    string _monitor;
+    string _mode;
+    bool _restore_best_weights;
+    List? _best_weights;
+    CallbackParams _parameters;
+    public Dictionary>? history { get; set; }
+    // The user needs to pass a CallbackParams to EarlyStopping; the CallbackParams must at least contain the Model.
+    public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", int min_delta = 0, int patience = 0,
+        int verbose = 1, string mode = "auto", float baseline = 0f, bool restore_best_weights = false,
+        int start_from_epoch = 0)
+    {
+        _parameters = parameters;
+        _stopped_epoch = 0;
+        _wait = 0;
+        _monitor = monitor;
+        _paitence = patience;
+        _verbose = verbose;
+        _baseline = baseline;
+        _start_from_epoch = start_from_epoch;
+        _min_delta = Math.Abs(min_delta);
+        _restore_best_weights = restore_best_weights;
+        _mode = mode;
+        if (mode != "auto" && mode != "min" && mode != "max")
+        {
+            Console.WriteLine("EarlyStopping mode {0} is unknown, fallback to auto mode.", mode);
+        }
+    }
+    public void on_train_begin()
+    {
+        _wait = 0;
+        _stopped_epoch = 0;
+        _best_epoch = 0;
+        _best = (float)np.Inf;
+    }
+
+    public void on_epoch_begin(int epoch)
+    {
+
+    }
+
+    public void on_train_batch_begin(long step)
+    {
+
+    }
+
+ public void on_train_batch_end(long end_step, Dictionary logs) + { + } + + public void on_epoch_end(int epoch, Dictionary epoch_logs) + { + var current = get_monitor_value(epoch_logs); + // If no monitor value exists or still in initial warm-up stage. + if (current == 0f || epoch < _start_from_epoch) + return; + // Restore the weights after first epoch if no progress is ever made. + if (_restore_best_weights && _best_weights == null) + { + _best_weights = _parameters.Model.TrainableWeights; + } + _wait += 1; + + if (_is_improvement(current, _best)) + { + _best = current; + _best_epoch = epoch; + if (_restore_best_weights) + _best_weights = _parameters.Model.TrainableWeights; + // Only restart wait if we beat both the baseline and our previous best. + if (_baseline == 0f || _is_improvement(current, _baseline)) + _wait = 0; + } + // Only check after the first epoch. + if (_wait >= _paitence && epoch > 0) + { + _stopped_epoch = epoch; + _parameters.Model.set_stopTraining_true(); + if (_restore_best_weights && _best_weights != null) + { + if (_verbose > 0) + { + Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}"); + } + } + // Because loading the weight variable into the model has not yet been implemented, so Earlystopping can't load best_weight yet. + // TODO(Wanglongzhi2001): implement it. + // _parameters.Model.load_weights(best_weights); + } + } + public void on_train_end() + { + if (_stopped_epoch > 0 && _verbose > 0) + { + Console.WriteLine($"Epoch {_stopped_epoch + 1}: early stopping"); + } + } + public void on_predict_begin() { } + public void on_predict_batch_begin(long step) { } + public void on_predict_batch_end(long end_step, Dictionary logs) { } + public void on_predict_end() { } + public void on_test_begin() { } + public void on_test_batch_begin(long step) { } + public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) { } + + float get_monitor_value(Dictionary logs) + { + logs = logs ?? 
new Dictionary(); + float monitor_value = logs[_monitor]; + if (monitor_value == 0f) + { + Console.WriteLine($"Early stopping conditioned on metric {_monitor} " + + $"which is not available. Available metrics are: {string.Join(", ", logs.Keys)}"); + } + return monitor_value; + } + public bool _is_improvement(float monitor_value, float reference_value) + { + bool less_op = (monitor_value - _min_delta) < reference_value; + bool greater_op = (monitor_value - _min_delta) >= reference_value; + if (_mode == "min") + return less_op; + else if (_mode == "max") + return greater_op; + else + { + if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc")) + { + return greater_op; + } + else + return less_op; + } + } +} diff --git a/src/TensorFlowNET.Keras/Callbacks/History.cs b/src/TensorFlowNET.Keras/Callbacks/History.cs index b2d3604e2..d61132612 100644 --- a/src/TensorFlowNET.Keras/Callbacks/History.cs +++ b/src/TensorFlowNET.Keras/Callbacks/History.cs @@ -23,6 +23,7 @@ public void on_test_begin() epochs = new List(); history = new Dictionary>(); } + public void on_train_end() { } public void on_epoch_begin(int epoch) { diff --git a/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs b/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs index 6462d3d97..d22c779fb 100644 --- a/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs +++ b/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs @@ -22,6 +22,7 @@ public void on_train_begin() _called_in_fit = true; _sw = new Stopwatch(); } + public void on_train_end() { } public void on_test_begin() { _sw = new Stopwatch(); diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index 39004183b..7ad4d3ef7 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -19,6 +19,7 @@ public partial class Model /// /// /// + /// /// /// /// @@ -26,6 +27,7 @@ public ICallback fit(NDArray x, NDArray y, int batch_size = 
-1, int epochs = 1, int verbose = 1, + List callbacks = null, float validation_split = 0f, bool shuffle = true, int initial_epoch = 0, @@ -59,7 +61,7 @@ public ICallback fit(NDArray x, NDArray y, StepsPerExecution = _steps_per_execution }); - return FitInternal(data_handler, epochs, verbose, validation_data: null, + return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, train_step_func: train_step_function); } @@ -67,6 +69,7 @@ public ICallback fit(IEnumerable x, NDArray y, int batch_size = -1, int epochs = 1, int verbose = 1, + List callbacks = null, float validation_split = 0f, bool shuffle = true, int initial_epoch = 0, @@ -107,12 +110,12 @@ public ICallback fit(IEnumerable x, NDArray y, if (data_handler.DataAdapter.GetDataset().structure.Length > 2 || data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1) { - return FitInternal(data_handler, epochs, verbose, validation_data: null, + return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, train_step_func: train_step_multi_inputs_function); } else { - return FitInternal(data_handler, epochs, verbose, validation_data: null, + return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, train_step_func: train_step_function); } } @@ -122,6 +125,7 @@ public History fit(IDatasetV2 dataset, int batch_size = -1, int epochs = 1, int verbose = 1, + List callbacks = null, float validation_split = 0f, bool shuffle = true, int initial_epoch = 0, @@ -143,11 +147,11 @@ public History fit(IDatasetV2 dataset, StepsPerExecution = _steps_per_execution }); - return FitInternal(data_handler, epochs, verbose, validation_data: validation_data, + return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data, train_step_func: train_step_function); } - History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data, + History 
FitInternal(DataHandler data_handler, int epochs, int verbose, List callbackList, IDatasetV2 validation_data,
        Func> train_step_func)
     {
         stop_training = false;
@@ -159,6 +163,13 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV
             Epochs = epochs,
             Steps = data_handler.Inferredsteps
         });
+
+        if (callbackList != null)
+        {
+            foreach(var callback in callbackList)
+                callbacks.callbacks.Add(callback);
+        }
+
         callbacks.on_train_begin();
 
         foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs
index 5b3cdbffc..c1d29f592 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.cs
@@ -144,5 +144,11 @@ public override IDictionary _trackable_children(SaveType save
         var children = base._trackable_children(save_type, cache);
         return children;
     }
+
+
+    void IModel.set_stopTraining_true()
+    {
+        stop_training = true;
+    }
 }
 }
diff --git a/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs b/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs
new file mode 100644
index 000000000..636b424f5
--- /dev/null
+++ b/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs
@@ -0,0 +1,65 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Tensorflow.Keras.UnitTest.Helpers;
+using static Tensorflow.Binding;
+using Tensorflow;
+using Tensorflow.Keras.Optimizers;
+using Tensorflow.Keras.Callbacks;
+using Tensorflow.Keras.Engine;
+using System.Collections.Generic;
+using static Tensorflow.KerasApi;
+using Tensorflow.Keras;
+
+
+namespace TensorFlowNET.Keras.UnitTest
+{
+    [TestClass]
+    public class EarlystoppingTest
+    {
+        [TestMethod]
+        // Because loading the weight variable into the model has not yet been implemented,
+        // you'd better not set patience too large, because the weights will equal the last epoch's weights.
+        public void Earlystopping()
+        {
+            var layers = keras.layers;
+            var model = keras.Sequential(new List
+            {
+                layers.Rescaling(1.0f / 255, input_shape: (32, 32, 3)),
+                layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu),
+                layers.MaxPooling2D(),
+                layers.Flatten(),
+                layers.Dense(128, activation: keras.activations.Relu),
+                layers.Dense(10)
+            });
+
+
+            model.summary();
+
+            model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
+                loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true),
+                metrics: new[] { "acc" });
+
+            var num_epochs = 3;
+            var batch_size = 8;
+
+            var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
+            x_train = x_train / 255.0f;
+            // Define a CallbackParams first; the parameters you pass must at least contain Model and Epochs.
+            CallbackParams callback_parameters = new CallbackParams
+            {
+                Model = model,
+                Epochs = num_epochs,
+            };
+            // Define your early stopping callback.
+            ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy");
+            // Define a callback list, then add the early stopping callback to it.
+            var callbacks = new List();
+            callbacks.Add(earlystop);
+
+            model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs, callbacks: callbacks);
+        }
+
+    }
+
+
+}