From fff5029b0240c30ee4f2b9329c71c8665e091858 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Tue, 27 Jun 2023 23:57:48 +0800 Subject: [PATCH 1/2] fix: revise earlystopping callback's min_delta parameter --- src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs index 73ccc87b0..59152d9b2 100644 --- a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs +++ b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs @@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Callbacks; public class EarlyStopping: ICallback { int _paitence; - int _min_delta; + float _min_delta; int _verbose; int _stopped_epoch; int _wait; @@ -26,7 +26,7 @@ public class EarlyStopping: ICallback CallbackParams _parameters; public Dictionary>? history { get; set; } // user need to pass a CallbackParams to EarlyStopping, CallbackParams at least need the model - public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", int min_delta = 0, int patience = 0, + public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", float min_delta = 0f, int patience = 0, int verbose = 1, string mode = "auto", float baseline = 0f, bool restore_best_weights = false, int start_from_epoch = 0) { From 81b10e37809d5ae57989f55d4102a0a367d4322c Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Wed, 28 Jun 2023 02:19:28 +0800 Subject: [PATCH 2/2] feat: Add GRUCell layer --- src/TensorFlowNET.Core/APIs/tf.tensor.cs | 13 +- .../Keras/ArgsDefinition/Rnn/GRUCellArgs.cs | 39 +++ .../Keras/Layers/ILayersApi.cs | 12 + src/TensorFlowNET.Keras/Layers/LayersApi.cs | 43 +++ src/TensorFlowNET.Keras/Layers/Rnn/GRUCell.cs | 282 ++++++++++++++++++ .../Layers/Rnn.Test.cs | 13 + 6 files changed, 399 insertions(+), 3 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs create mode 100644 src/TensorFlowNET.Keras/Layers/Rnn/GRUCell.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.tensor.cs b/src/TensorFlowNET.Core/APIs/tf.tensor.cs index 45aebc0cd..b03168ab3 100644 --- a/src/TensorFlowNET.Core/APIs/tf.tensor.cs +++ b/src/TensorFlowNET.Core/APIs/tf.tensor.cs @@ -68,20 +68,27 @@ public Tensor strided_slice(Tensor input, T[] begin, T[] end, T[] strides = n /// A name for the operation (optional) /// if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects; /// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value. - public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null) + public Tensor[] split(Tensor value, int num_split, Axis axis, string name = null) => array_ops.split( value: value, num_or_size_splits: num_split, axis: axis, name: name); - public Tensor[] split(Tensor value, int num_split, int axis, string name = null) + public Tensor[] split(Tensor value, int[] num_split, Axis axis, string name = null) => array_ops.split( value: value, num_or_size_splits: num_split, - axis: ops.convert_to_tensor(axis), + axis: axis, name: name); + //public Tensor[] split(Tensor value, int num_split, Axis axis, string name = null) + // => array_ops.split( + // value: value, + // num_or_size_splits: num_split, + // axis: axis, + // name: name); + public Tensor ensure_shape(Tensor x, Shape shape, string name = null) { return gen_ops.ensure_shape(x, shape, name); diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs new file mode 100644 index 000000000..75d5d0218 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs @@ -0,0 +1,39 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition.Rnn +{ + public class GRUCellArgs : AutoSerializeLayerArgs + { + [JsonProperty("units")] + public int Units { get; set; } + // TODO(Rinne): lack of initialized value of Activation. Merging keras + // into tf.net could resolve it. + [JsonProperty("activation")] + public Activation Activation { get; set; } + [JsonProperty("recurrent_activation")] + public Activation RecurrentActivation { get; set; } + [JsonProperty("use_bias")] + public bool UseBias { get; set; } = true; + [JsonProperty("dropout")] + public float Dropout { get; set; } = .0f; + [JsonProperty("recurrent_dropout")] + public float RecurrentDropout { get; set; } = .0f; + [JsonProperty("kernel_initializer")] + public IInitializer KernelInitializer { get; set; } + [JsonProperty("recurrent_initializer")] + public IInitializer RecurrentInitializer { get; set; } + [JsonProperty("bias_initializer")] + public IInitializer BiasInitializer { get; set; } + [JsonProperty("reset_after")] + public bool ResetAfter { get;set; } + [JsonProperty("implementation")] + public int Implementation { get; set; } = 2; + + + + } + +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs index a19508d42..9bc99701d 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs @@ -246,6 +246,18 @@ public ILayer RNN( bool time_major = false ); + public IRnnCell GRUCell( + int units, + string activation = "tanh", + string recurrent_activation = "sigmoid", + bool use_bias = true, + string kernel_initializer = "glorot_uniform", + string recurrent_initializer = "orthogonal", + string bias_initializer = "zeros", + float dropout = 0f, + float recurrent_dropout = 0f, + bool reset_after = true); + public ILayer Subtract(); } } diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index 0bdcbc841..d20803375 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -873,6 +873,45 @@ public ILayer LSTM(int units, UnitForgetBias = unit_forget_bias }); + /// + /// Cell class for the GRU layer. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public IRnnCell GRUCell( + int units, + string activation = "tanh", + string recurrent_activation = "sigmoid", + bool use_bias = true, + string kernel_initializer = "glorot_uniform", + string recurrent_initializer = "orthogonal", + string bias_initializer = "zeros", + float dropout = 0f, + float recurrent_dropout = 0f, + bool reset_after = true) + => new GRUCell(new GRUCellArgs + { + Units = units, + Activation = keras.activations.GetActivationFromName(activation), + RecurrentActivation = keras.activations.GetActivationFromName(recurrent_activation), + KernelInitializer = GetInitializerByName(kernel_initializer), + RecurrentInitializer = GetInitializerByName(recurrent_initializer), + BiasInitializer = GetInitializerByName(bias_initializer), + UseBias = use_bias, + Dropout = dropout, + RecurrentDropout = recurrent_dropout, + ResetAfter = reset_after + }); + /// /// /// @@ -983,5 +1022,9 @@ public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? me Variance = variance, Invert = invert }); + + + + } } diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/GRUCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/GRUCell.cs new file mode 100644 index 000000000..02fe54f49 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/GRUCell.cs @@ -0,0 +1,282 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.ArgsDefinition.Rnn; +using Tensorflow.Common.Extensions; +using Tensorflow.Common.Types; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Layers.Rnn +{ + /// + /// Cell class for the GRU layer. + /// + public class GRUCell : DropoutRNNCellMixin + { + GRUCellArgs _args; + IVariableV1 _kernel; + IVariableV1 _recurrent_kernel; + IInitializer _bias_initializer; + IVariableV1 _bias; + INestStructure _state_size; + INestStructure _output_size; + int Units; + public override INestStructure StateSize => _state_size; + + public override INestStructure OutputSize => _output_size; + + public override bool SupportOptionalArgs => false; + public GRUCell(GRUCellArgs args) : base(args) + { + _args = args; + if (_args.Units <= 0) + { + throw new ValueError( + $"units must be a positive integer, got {args.Units}"); + } + _args.Dropout = Math.Min(1f, Math.Max(0f, _args.Dropout)); + _args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout)); + if (_args.RecurrentDropout != 0f && _args.Implementation != 1) + { + Debug.WriteLine("RNN `implementation=2` is not supported when `recurrent_dropout` is set." + + "Using `implementation=1`."); + _args.Implementation = 1; + } + Units = _args.Units; + _state_size = new NestList(Units); + _output_size = new NestNode(Units); + } + + public override void build(KerasShapesWrapper input_shape) + { + //base.build(input_shape); + + var single_shape = input_shape.ToSingleShape(); + var input_dim = single_shape[-1]; + + _kernel = add_weight("kernel", (input_dim, _args.Units * 3), + initializer: _args.KernelInitializer + ); + + _recurrent_kernel = add_weight("recurrent_kernel", (Units, Units * 3), + initializer: _args.RecurrentInitializer + ); + if (_args.UseBias) + { + Shape bias_shape; + if (!_args.ResetAfter) + { + bias_shape = new Shape(3 * Units); + } + else + { + bias_shape = (2, 3 * Units); + } + _bias = add_weight("bias", bias_shape, + initializer: _bias_initializer + ); + } + built = true; + } + + protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null) + { + var h_tm1 = states.IsNested() ? states[0] : states.Single(); + var dp_mask = get_dropout_mask_for_cell(inputs, training.Value, count: 3); + var rec_dp_mask = get_recurrent_dropout_mask_for_cell(h_tm1, training.Value, count: 3); + + IVariableV1 input_bias = _bias; + IVariableV1 recurrent_bias = _bias; + if (_args.UseBias) + { + if (!_args.ResetAfter) + { + input_bias = _bias; + recurrent_bias = null; + } + else + { + input_bias = tf.Variable(tf.unstack(_bias.AsTensor())[0]); + recurrent_bias = tf.Variable(tf.unstack(_bias.AsTensor())[1]); + } + } + + + Tensor hh; + Tensor z; + if ( _args.Implementation == 1) + { + Tensor inputs_z; + Tensor inputs_r; + Tensor inputs_h; + if (0f < _args.Dropout && _args.Dropout < 1f) + { + inputs_z = inputs * dp_mask[0]; + inputs_r = inputs * dp_mask[1]; + inputs_h = inputs * dp_mask[2]; + } + else + { + inputs_z = inputs.Single(); + inputs_r = inputs.Single(); + inputs_h = inputs.Single(); + } + + + int startIndex = (int)_kernel.AsTensor().shape[0]; + var _kernel_slice = tf.slice(_kernel.AsTensor(), + new[] { 0, 0 }, new[] { startIndex, Units }); + var x_z = math_ops.matmul(inputs_z, _kernel_slice); + _kernel_slice = tf.slice(_kernel.AsTensor(), + new[] { 0, Units }, new[] { Units, Units}); + var x_r = math_ops.matmul( + inputs_r, _kernel_slice); + int endIndex = (int)_kernel.AsTensor().shape[1]; + _kernel_slice = tf.slice(_kernel.AsTensor(), + new[] { 0, Units * 2 }, new[] { startIndex, endIndex - Units * 2 }); + var x_h = math_ops.matmul(inputs_h, _kernel_slice); + + if(_args.UseBias) + { + x_z = tf.nn.bias_add( + x_z, tf.Variable(input_bias.AsTensor()[$":{Units}"])); + x_r = tf.nn.bias_add( + x_r, tf.Variable(input_bias.AsTensor()[$"{Units}:{Units * 2}"])); + x_h = tf.nn.bias_add( + x_h, tf.Variable(input_bias.AsTensor()[$"{Units * 2}:"])); + } + + Tensor h_tm1_z; + Tensor h_tm1_r; + Tensor h_tm1_h; + if (0f < _args.RecurrentDropout && _args.RecurrentDropout < 1f) + { + h_tm1_z = h_tm1 * rec_dp_mask[0]; + h_tm1_r = h_tm1 * rec_dp_mask[1]; + h_tm1_h = h_tm1 * rec_dp_mask[2]; + } + else + { + h_tm1_z = h_tm1; + h_tm1_r = h_tm1; + h_tm1_h = h_tm1; + } + + startIndex = (int)_recurrent_kernel.AsTensor().shape[0]; + var _recurrent_kernel_slice = tf.slice(_recurrent_kernel.AsTensor(), + new[] { 0, 0 }, new[] { startIndex, Units }); + var recurrent_z = math_ops.matmul( + h_tm1_z, _recurrent_kernel_slice); + _recurrent_kernel_slice = tf.slice(_recurrent_kernel.AsTensor(), + new[] { 0, Units }, new[] { startIndex, Units}); + var recurrent_r = math_ops.matmul( + h_tm1_r, _recurrent_kernel_slice); + if(_args.ResetAfter && _args.UseBias) + { + recurrent_z = tf.nn.bias_add( + recurrent_z, tf.Variable(recurrent_bias.AsTensor()[$":{Units}"])); + recurrent_r = tf.nn.bias_add( + recurrent_r, tf.Variable(recurrent_bias.AsTensor()[$"{Units}: {Units * 2}"])); + } + z = _args.RecurrentActivation.Apply(x_z + recurrent_z); + var r = _args.RecurrentActivation.Apply(x_r + recurrent_r); + + Tensor recurrent_h; + if (_args.ResetAfter) + { + endIndex = (int)_recurrent_kernel.AsTensor().shape[1]; + _recurrent_kernel_slice = tf.slice(_recurrent_kernel.AsTensor(), + new[] { 0, Units * 2 }, new[] { startIndex, endIndex - Units * 2 }); + recurrent_h = math_ops.matmul( + h_tm1_h, _recurrent_kernel_slice); + if(_args.UseBias) + { + recurrent_h = tf.nn.bias_add( + recurrent_h, tf.Variable(recurrent_bias.AsTensor()[$"{Units * 2}:"])); + } + recurrent_h *= r; + } + else + { + _recurrent_kernel_slice = tf.slice(_recurrent_kernel.AsTensor(), + new[] { 0, Units * 2 }, new[] { startIndex, endIndex - Units * 2 }); + recurrent_h = math_ops.matmul( + r * h_tm1_h, _recurrent_kernel_slice); + } + hh = _args.Activation.Apply(x_h + recurrent_h); + } + else + { + if (0f < _args.Dropout && _args.Dropout < 1f) + { + inputs = inputs * dp_mask[0]; + } + + var matrix_x = math_ops.matmul(inputs, _kernel.AsTensor()); + if(_args.UseBias) + { + matrix_x = tf.nn.bias_add(matrix_x, input_bias); + } + var matrix_x_spilted = tf.split(matrix_x, 3, axis: -1); + var x_z = matrix_x_spilted[0]; + var x_r = matrix_x_spilted[1]; + var x_h = matrix_x_spilted[2]; + + Tensor matrix_inner; + if (_args.ResetAfter) + { + matrix_inner = math_ops.matmul(h_tm1, _recurrent_kernel.AsTensor()); + if ( _args.UseBias) + { + matrix_inner = tf.nn.bias_add( + matrix_inner, recurrent_bias); + } + } + else + { + var startIndex = (int)_recurrent_kernel.AsTensor().shape[0]; + var _recurrent_kernel_slice = tf.slice(_recurrent_kernel.AsTensor(), + new[] { 0, 0 }, new[] { startIndex, Units * 2 }); + matrix_inner = math_ops.matmul( + h_tm1, _recurrent_kernel_slice); + } + + var matrix_inner_splitted = tf.split(matrix_inner, new int[] {Units, Units, -1}, axis:-1); + var recurrent_z = matrix_inner_splitted[0]; + var recurrent_r = matrix_inner_splitted[0]; + var recurrent_h = matrix_inner_splitted[0]; + + z = _args.RecurrentActivation.Apply(x_z + recurrent_z); + var r = _args.RecurrentActivation.Apply(x_r + recurrent_r); + + if(_args.ResetAfter) + { + recurrent_h = r * recurrent_h; + } + else + { + var startIndex = (int)_recurrent_kernel.AsTensor().shape[0]; + var endIndex = (int)_recurrent_kernel.AsTensor().shape[1]; + var _recurrent_kernel_slice = tf.slice(_recurrent_kernel.AsTensor(), + new[] { 0, 2*Units }, new[] { startIndex, endIndex - 2 * Units }); + recurrent_h = math_ops.matmul( + r * h_tm1, _recurrent_kernel_slice); + } + hh = _args.Activation.Apply(x_h + recurrent_h); + } + var h = z * h_tm1 + (1 - z) * hh; + if (states.IsNested()) + { + var new_state = new NestList(h); + return new Nest(new INestStructure[] { new NestNode(h), new_state }).ToTensors(); + } + else + { + return new Nest(new INestStructure[] { new NestNode(h), new NestNode(h)}).ToTensors(); + } + + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs index 8eeee7a88..becdbcd60 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs @@ -132,5 +132,18 @@ public void RNNForLSTMCell() Console.WriteLine($"output: {output}"); Assert.AreEqual((5, 4), output.shape); } + + [TestMethod] + public void GRUCell() + { + var inputs = tf.random.normal((32, 10, 8)); + var rnn = tf.keras.layers.RNN(tf.keras.layers.GRUCell(4)); + var output = rnn.Apply(inputs); + Assert.AreEqual((32, 4), output.shape); + rnn = tf.keras.layers.RNN(tf.keras.layers.GRUCell(4, reset_after:false, use_bias:false)); + output = rnn.Apply(inputs); + Assert.AreEqual((32, 4), output.shape); + + } } }