Skip to content

Commit 70d681c

Browse files
authored
Merge pull request #1168 from Wanglongzhi2001/master
feat: implement GRU layer
2 parents ba1ddb4 + 7b077ea commit 70d681c

File tree

7 files changed

+300
-41
lines changed

7 files changed

+300
-41
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
    /// <summary>
    /// Constructor arguments for the <c>GRU</c> layer.
    /// </summary>
    public class GRUArgs : AutoSerializeLayerArgs
    {
        /// <summary>Dimensionality of the output space.</summary>
        public int Units { get; set; }

        /// <summary>Activation function applied to the candidate state.</summary>
        public Activation Activation { get; set; }

        /// <summary>Activation function applied to the update/reset gates.</summary>
        public Activation RecurrentActivation { get; set; }

        /// <summary>Whether the layer uses a bias vector. Defaults to true.</summary>
        public bool UseBias { get; set; } = true;

        /// <summary>Fraction of input units dropped for the linear transformation of the inputs.</summary>
        public float Dropout { get; set; } = 0f;

        /// <summary>Fraction of units dropped for the linear transformation of the recurrent state.</summary>
        public float RecurrentDropout { get; set; } = 0f;

        /// <summary>Initializer for the input kernel weights.</summary>
        public IInitializer KernelInitializer { get; set; }

        /// <summary>Initializer for the recurrent kernel weights.</summary>
        public IInitializer RecurrentInitializer { get; set; }

        /// <summary>Initializer for the bias vector.</summary>
        public IInitializer BiasInitializer { get; set; }

        /// <summary>Whether to return the full output sequence instead of only the last output.</summary>
        public bool ReturnSequences { get; set; }

        /// <summary>Whether to return the final state in addition to the output.</summary>
        public bool ReturnState { get; set; }

        /// <summary>Whether to process the input sequence backwards.</summary>
        public bool GoBackwards { get; set; }

        /// <summary>Whether batch-final states carry over as the next batch's initial states.</summary>
        public bool Stateful { get; set; }

        /// <summary>Whether to unroll the recurrence instead of using a symbolic loop.</summary>
        public bool Unroll { get; set; }

        /// <summary>Whether inputs/outputs are laid out time-major, i.e. (timesteps, batch, ...).</summary>
        public bool TimeMajor { get; set; }

        /// <summary>GRU convention: apply the reset gate after (true, cuDNN compatible) or before matrix multiplication.</summary>
        public bool ResetAfter { get; set; }

        /// <summary>Implementation mode; 2 batches the gate operations (default).</summary>
        public int Implementation { get; set; } = 2;
    }
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
    /// <summary>
    /// Optional call-time arguments accepted by the <c>GRU</c> layer.
    /// </summary>
    public class GRUOptionalArgs
    {
        /// <summary>Identifies which layer kind these optional args belong to.</summary>
        public string Identifier => "GRU";

        /// <summary>Optional boolean mask tensor; null when no masking is requested.</summary>
        public Tensor Mask { get; set; } = null;
    }
}

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,25 @@ public IRnnCell GRUCell(
259259
float recurrent_dropout = 0f,
260260
bool reset_after = true);
261261

262+
/// <summary>
/// Gated Recurrent Unit - Cho et al. 2014.
/// </summary>
/// <param name="units">Positive integer, dimensionality of the output space.</param>
/// <param name="activation">Activation function to use for the candidate state.</param>
/// <param name="recurrent_activation">Activation function to use for the recurrent step (gates).</param>
/// <param name="use_bias">Boolean (default true), whether the layer uses a bias vector.</param>
/// <param name="kernel_initializer">Initializer for the input kernel weights matrix.</param>
/// <param name="recurrent_initializer">Initializer for the recurrent kernel weights matrix.</param>
/// <param name="bias_initializer">Initializer for the bias vector.</param>
/// <param name="dropout">Float between 0 and 1; fraction of units dropped for the input transformation.</param>
/// <param name="recurrent_dropout">Float between 0 and 1; fraction of units dropped for the recurrent transformation.</param>
/// <param name="return_sequences">Whether to return the full output sequence rather than only the last output.</param>
/// <param name="return_state">Whether to return the last state in addition to the output.</param>
/// <param name="go_backwards">If true, process the input sequence backwards and return the reversed sequence.</param>
/// <param name="stateful">If true, each sample's final state is reused as its initial state in the next batch.</param>
/// <param name="unroll">If true, the network is unrolled instead of using a symbolic loop.</param>
/// <param name="time_major">The shape format of the inputs and outputs tensors (time-major when true).</param>
/// <param name="reset_after">GRU convention: apply reset gate after (true, cuDNN compatible) or before matrix multiplication.</param>
/// <returns>The configured GRU layer.</returns>
public ILayer GRU(
    int units,
    string activation = "tanh",
    string recurrent_activation = "sigmoid",
    bool use_bias = true,
    string kernel_initializer = "glorot_uniform",
    string recurrent_initializer = "orthogonal",
    string bias_initializer = "zeros",
    float dropout = 0f,
    float recurrent_dropout = 0f,
    bool return_sequences = false,
    bool return_state = false,
    bool go_backwards = false,
    bool stateful = false,
    bool unroll = false,
    bool time_major = false,
    bool reset_after = true
);
280+
262281
/// <summary>
263282
/// Bidirectional wrapper for RNNs.
264283
/// </summary>

src/TensorFlowNET.Keras/Layers/LayersApi.cs

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -784,7 +784,7 @@ public IRnnCell LSTMCell(int uints,
784784
string recurrent_activation = "sigmoid",
785785
bool use_bias = true,
786786
string kernel_initializer = "glorot_uniform",
787-
string recurrent_initializer = "orthogonal", // TODO(Wanglongzhi2001),glorot_uniform has not been developed.
787+
string recurrent_initializer = "orthogonal",
788788
string bias_initializer = "zeros",
789789
bool unit_forget_bias = true,
790790
float dropout = 0f,
@@ -908,6 +908,65 @@ public IRnnCell GRUCell(
908908
ResetAfter = reset_after
909909
});
910910

911+
/// <summary>
/// Gated Recurrent Unit - Cho et al. 2014.
/// </summary>
/// <param name="units">Positive integer, dimensionality of the output space.</param>
/// <param name="activation">Activation function to use. If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`).</param>
/// <param name="recurrent_activation">Activation function to use for the recurrent step. If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`).</param>
/// <param name="use_bias">Boolean, (default `True`), whether the layer uses a bias vector.</param>
/// <param name="kernel_initializer">Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. Default: `glorot_uniform`.</param>
/// <param name="recurrent_initializer">Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. Default: `orthogonal`.</param>
/// <param name="bias_initializer">Initializer for the bias vector. Default: `zeros`.</param>
/// <param name="dropout">Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. Default: 0.</param>
/// <param name="recurrent_dropout">Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. Default: 0.</param>
/// <param name="return_sequences">Boolean. Whether to return the last output in the output sequence, or the full sequence. Default: `False`.</param>
/// <param name="return_state">Boolean. Whether to return the last state in addition to the output. Default: `False`.</param>
/// <param name="go_backwards">Boolean (default `False`). If True, process the input sequence backwards and return the reversed sequence.</param>
/// <param name="stateful">Boolean (default False). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch.</param>
/// <param name="unroll">Boolean (default False). If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed-up a RNN.</param>
/// <param name="time_major">The shape format of the `inputs` and `outputs` tensors.</param>
/// <param name="reset_after">GRU convention (whether to apply reset gate after or before matrix multiplication). False = "before", True = "after" (default and cuDNN compatible).</param>
/// <returns>The configured GRU layer.</returns>
public ILayer GRU(
    int units,
    string activation = "tanh",
    string recurrent_activation = "sigmoid",
    bool use_bias = true,
    string kernel_initializer = "glorot_uniform",
    string recurrent_initializer = "orthogonal",
    string bias_initializer = "zeros",
    float dropout = 0f,
    float recurrent_dropout = 0f,
    bool return_sequences = false,
    bool return_state = false,
    bool go_backwards = false,
    bool stateful = false,
    bool unroll = false,
    bool time_major = false,
    bool reset_after = true
    )
{
    // Fixed: the XML docs previously declared a <param name="implementation">
    // entry with no matching parameter (CS1572); the layer uses GRUArgs'
    // default Implementation value instead.
    var args = new GRUArgs
    {
        Units = units,
        Activation = keras.activations.GetActivationFromName(activation),
        RecurrentActivation = keras.activations.GetActivationFromName(recurrent_activation),
        KernelInitializer = GetInitializerByName(kernel_initializer),
        RecurrentInitializer = GetInitializerByName(recurrent_initializer),
        BiasInitializer = GetInitializerByName(bias_initializer),
        UseBias = use_bias,
        Dropout = dropout,
        RecurrentDropout = recurrent_dropout,
        ReturnSequences = return_sequences,
        ReturnState = return_state,
        GoBackwards = go_backwards,
        Stateful = stateful,
        TimeMajor = time_major,
        Unroll = unroll,
        ResetAfter = reset_after
    };
    return new GRU(args);
}
969+
911970
public ILayer Bidirectional(
912971
ILayer layer,
913972
string merge_mode = "concat",
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Saving;


namespace Tensorflow.Keras.Layers
{
    /// <summary>
    /// Gated Recurrent Unit - Cho et al. 2014.
    /// Drives a <see cref="GRUCell"/> through the generic <see cref="RNN"/> loop.
    /// </summary>
    public class GRU : RNN
    {
        GRUArgs _args;

        // BUGFIX: this field was declared `static`, so every GRU instance shared
        // (and silently overwrote) a single cell. Each layer must own its own
        // cell and therefore its own weights.
        private GRUCell _cell;

        bool _return_runtime;

        public GRUCell Cell { get => _cell; }
        public int units { get => _args.Units; }
        public Activation activation { get => _args.Activation; }
        public Activation recurrent_activation { get => _args.RecurrentActivation; }
        public bool use_bias { get => _args.UseBias; }
        public float dropout { get => _args.Dropout; }
        public float recurrent_dropout { get => _args.RecurrentDropout; }
        public IInitializer kernel_initializer { get => _args.KernelInitializer; }
        public IInitializer recurrent_initializer { get => _args.RecurrentInitializer; }
        public IInitializer bias_initializer { get => _args.BiasInitializer; }
        public int implementation { get => _args.Implementation; }
        public bool reset_after { get => _args.ResetAfter; }

        public GRU(GRUArgs args) : base(CreateCell(args), PreConstruct(args))
        {
            _args = args;

            if (_args.Implementation == 0)
            {
                // Use the red output to act as a warning message that can also be used under the release version
                Console.ForegroundColor = ConsoleColor.Red;
                Console.WriteLine("Warning: `implementation=0` has been deprecated, "+
                            "and now defaults to `implementation=2`."+
                            "Please update your layer call.");
                Console.ResetColor();
            }

            // NOTE(review): this builds a second GRUCell distinct from the one
            // handed to the base RNN constructor via CreateCell(args); Call()
            // below drives this instance. Confirm the base class does not also
            // build weights on its own copy.
            _cell = new GRUCell(new GRUCellArgs
            {
                Units = _args.Units,
                Activation = _args.Activation,
                RecurrentActivation = _args.RecurrentActivation,
                UseBias = _args.UseBias,
                Dropout = _args.Dropout,
                RecurrentDropout = _args.RecurrentDropout,
                KernelInitializer = _args.KernelInitializer,
                RecurrentInitializer = _args.RecurrentInitializer,
                BiasInitializer = _args.BiasInitializer,
                ResetAfter = _args.ResetAfter,
                Implementation = _args.Implementation
            });
        }

        /// <summary>
        /// Runs the recurrence over <paramref name="inputs"/>. Returns the full
        /// sequence when ReturnSequences is set, otherwise the last output;
        /// appends the final states when ReturnState is set.
        /// </summary>
        /// <exception cref="ArgumentException">When optional_args is not a GRUOptionalArgs.</exception>
        protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
        {
            GRUOptionalArgs? gru_optional_args = optional_args as GRUOptionalArgs;
            if (optional_args is not null && gru_optional_args is null)
            {
                throw new ArgumentException("The type of optional args should be `GRUOptionalArgs`.");
            }
            Tensors? mask = gru_optional_args?.Mask;

            // Ragged input is not supported yet.
            bool is_ragged_input = false;

            _validate_args_if_ragged(is_ragged_input, mask);

            // GRU does not support constants; ignore them during processing.
            (inputs, initial_state, _) = this._process_inputs(inputs, initial_state, null);

            // BUGFIX: guard against a null mask — the original dereferenced
            // `mask.Length` unconditionally and threw a NullReferenceException
            // whenever the layer was called without optional args.
            if (mask is not null && mask.Length > 1)
            {
                mask = mask[0];
            }

            var input_shape = inputs.shape;
            var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];

            // TODO(Wanglongzhi2001), finish _could_use_gpu_kernel part
            Func<Tensors, Tensors, (Tensors, Tensors)> step = (cell_inputs, cell_states) =>
            {
                var res = Cell.Apply(cell_inputs, cell_states, training is null ? true : training.Value);
                var (output, state) = res;
                return (output, state);
            };

            var (last_output, outputs, states) = keras.backend.rnn(
                step,
                inputs,
                initial_state,
                constants: null,
                go_backwards: _args.GoBackwards,
                mask: mask,
                unroll: _args.Unroll,
                input_length: ops.convert_to_tensor(timesteps),
                time_major: _args.TimeMajor,
                zero_output_for_mask: base.Args.ZeroOutputForMask,
                return_all_outputs: _args.ReturnSequences
            );

            Tensors output;
            if (_args.ReturnSequences)
            {
                output = outputs;
            }
            else
            {
                output = last_output;
            }

            if (_args.ReturnState)
            {
                output = new Tensors { output, states };
            }
            return output;
        }

        // Builds the inner cell that is forwarded to the base RNN constructor.
        private static IRnnCell CreateCell(GRUArgs gruArgs)
        {
            return new GRUCell(new GRUCellArgs
            {
                Units = gruArgs.Units,
                Activation = gruArgs.Activation,
                RecurrentActivation = gruArgs.RecurrentActivation,
                UseBias = gruArgs.UseBias,
                Dropout = gruArgs.Dropout,
                RecurrentDropout = gruArgs.RecurrentDropout,
                KernelInitializer = gruArgs.KernelInitializer,
                RecurrentInitializer = gruArgs.RecurrentInitializer,
                BiasInitializer = gruArgs.BiasInitializer,
                ResetAfter = gruArgs.ResetAfter,
                Implementation = gruArgs.Implementation
            });
        }

        // Projects GRUArgs onto the RNNArgs consumed by the base class.
        private static RNNArgs PreConstruct(GRUArgs args)
        {
            return new RNNArgs
            {
                ReturnSequences = args.ReturnSequences,
                ReturnState = args.ReturnState,
                GoBackwards = args.GoBackwards,
                Stateful = args.Stateful,
                Unroll = args.Unroll,
                TimeMajor = args.TimeMajor,
                Units = args.Units,
                Activation = args.Activation,
                RecurrentActivation = args.RecurrentActivation,
                UseBias = args.UseBias,
                Dropout = args.Dropout,
                RecurrentDropout = args.RecurrentDropout,
                KernelInitializer = args.KernelInitializer,
                RecurrentInitializer = args.RecurrentInitializer,
                BiasInitializer = args.BiasInitializer
            };
        }
    }
}

src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ public class RNN : RnnBase
2525
private RNNArgs _args;
2626
private object _input_spec = null; // or NoneValue??
2727
private object _state_spec = null;
28-
private Tensors _states = null;
2928
private object _constants_spec = null;
29+
private Tensors _states = null;
3030
private int _num_constants;
3131
protected IVariableV1 _kernel;
3232
protected IVariableV1 _bias;
@@ -469,7 +469,7 @@ public override Tensors Apply(Tensors inputs, Tensors initial_states = null, boo
469469
return (inputs, initial_state, constants);
470470
}
471471

472-
private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
472+
protected void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
473473
{
474474
if (!is_ragged_input)
475475
{
@@ -528,44 +528,6 @@ public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = n
528528
throw new NotImplementedException();
529529
}
530530

531-
// 好像不能cell不能传接口类型
532-
//public RNN New(IRnnArgCell cell,
533-
// bool return_sequences = false,
534-
// bool return_state = false,
535-
// bool go_backwards = false,
536-
// bool stateful = false,
537-
// bool unroll = false,
538-
// bool time_major = false)
539-
// => new RNN(new RNNArgs
540-
// {
541-
// Cell = cell,
542-
// ReturnSequences = return_sequences,
543-
// ReturnState = return_state,
544-
// GoBackwards = go_backwards,
545-
// Stateful = stateful,
546-
// Unroll = unroll,
547-
// TimeMajor = time_major
548-
// });
549-
550-
//public RNN New(List<IRnnArgCell> cell,
551-
// bool return_sequences = false,
552-
// bool return_state = false,
553-
// bool go_backwards = false,
554-
// bool stateful = false,
555-
// bool unroll = false,
556-
// bool time_major = false)
557-
// => new RNN(new RNNArgs
558-
// {
559-
// Cell = cell,
560-
// ReturnSequences = return_sequences,
561-
// ReturnState = return_state,
562-
// GoBackwards = go_backwards,
563-
// Stateful = stateful,
564-
// Unroll = unroll,
565-
// TimeMajor = time_major
566-
// });
567-
568-
569531
protected Tensors get_initial_state(Tensors inputs)
570532
{
571533
var input = inputs[0];

test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,15 @@ public void GRUCell()
146146

147147
}
148148

149+
[TestMethod]
public void GRU()
{
    // A batch of 32 sequences, each with 10 timesteps of 8 features.
    var sequenceBatch = tf.ones((32, 10, 8));
    var gruLayer = tf.keras.layers.GRU(4);

    var result = gruLayer.Apply(sequenceBatch);

    // With default settings the layer returns only the last output: (batch, units).
    Assert.AreEqual((32, 4), result.shape);
}
157+
149158
[TestMethod]
150159
public void Bidirectional()
151160
{

0 commit comments

Comments
 (0)