From 2eef1b812a00228a153b292ce5fe8c058b1c7fff Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Sun, 9 Apr 2023 21:56:38 +0800 Subject: [PATCH] Finish EarlyStopping --- src/TensorFlowNET.Core/Keras/Layers/ILayer.cs | 2 +- src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs | 2 ++ src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index 55409df36..9d69d5d0b 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -17,7 +17,7 @@ public interface ILayer: IWithTrackable, IKerasConfigable List TrainableVariables { get; } List TrainableWeights { get; } List NonTrainableWeights { get; } - List Weights { get; set} + List Weights { get; set; } Shape OutputShape { get; } Shape BatchInputShape { get; } TensorShapeConfig BuildInputShape { get; } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 87b595b64..bc4daf13f 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -84,6 +84,8 @@ public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell protected bool built = false; public bool Built => built; + List ILayer.Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + public RnnCell(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid, diff --git a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs index cba621fae..73ccc87b0 100644 --- a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs +++ b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs @@ -102,8 +102,8 @@ public void on_epoch_end(int epoch, Dictionary epoch_logs) { Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}"); } + _parameters.Model.Weights = _best_weights; } - _parameters.Model.Weights = _best_weights; } } public void on_train_end()