Skip to content

Commit 14da379

Browse files
authored
Merge pull request #1020 from Wanglongzhi2001/master
Finish EarlyStopping
2 parents 33333df + 936c0b4 commit 14da379

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

src/TensorFlowNET.Core/Keras/Layers/ILayer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public interface ILayer: IWithTrackable, IKerasConfigable
1717
List<IVariableV1> TrainableVariables { get; }
1818
List<IVariableV1> TrainableWeights { get; }
1919
List<IVariableV1> NonTrainableWeights { get; }
20-
List<IVariableV1> Weights { get; }
20+
List<IVariableV1> Weights { get; set}
2121
Shape OutputShape { get; }
2222
Shape BatchInputShape { get; }
2323
TensorShapeConfig BuildInputShape { get; }

src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs

+2-4
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
7777
// Restore the weights after first epoch if no progress is ever made.
7878
if (_restore_best_weights && _best_weights == null)
7979
{
80-
_best_weights = _parameters.Model.TrainableWeights;
80+
_best_weights = _parameters.Model.Weights;
8181
}
8282
_wait += 1;
8383

@@ -103,9 +103,7 @@ public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
103103
Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}");
104104
}
105105
}
106-
// Because loading the weight variable into the model has not yet been implemented, so Earlystopping can't load best_weight yet.
107-
// TODO(Wanglongzhi2001): implement it.
108-
// _parameters.Model.load_weights(best_weights);
106+
_parameters.Model.Weights = _best_weights;
109107
}
110108
}
111109
public void on_train_end()

0 commit comments

Comments
 (0)