@@ -13,17 +13,17 @@ namespace Tensorflow.Keras.Layers
    /// </summary>
    public class Bidirectional : Wrapper
    {
-        BidirectionalArgs _args;
-        RNN _forward_layer;
-        RNN _backward_layer;
-        RNN _layer;
-        bool _support_masking = true;
        int _num_constants = 0;
+        bool _support_masking = true;
        bool _return_state;
        bool _stateful;
        bool _return_sequences;
-        InputSpec _input_spec;
+        BidirectionalArgs _args;
        RNNArgs _layer_args_copy;
+        RNN _forward_layer;
+        RNN _backward_layer;
+        RNN _layer;
+        InputSpec _input_spec;
        public Bidirectional(BidirectionalArgs args) : base(args)
        {
            _args = args;
@@ -66,12 +66,16 @@ public Bidirectional(BidirectionalArgs args):base(args)

            // Recreate the forward layer from the original layer config, so that it
            // will not carry over any state from the layer.
-            var actualType = _layer.GetType();
-            if (actualType == typeof(LSTM))
+            if (_layer is LSTM)
            {
                var arg = _layer_args_copy as LSTMArgs;
                _forward_layer = new LSTM(arg);
            }
+            else if (_layer is SimpleRNN)
+            {
+                var arg = _layer_args_copy as SimpleRNNArgs;
+                _forward_layer = new SimpleRNN(arg);
+            }
            // TODO(Wanglongzhi2001), add GRU if case.
            else
            {
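A side note on the dispatch change above: replacing the `GetType() == typeof(LSTM)` comparison with `_layer is LSTM` is not only shorter, it also changes the matching semantics slightly, because a type pattern matches derived types as well as the exact type. A minimal standalone sketch of that difference (the `Rnn`, `Lstm`, and `PeepholeLstm` classes below are hypothetical stand-ins, not the types from this file):

```csharp
using System;

// Hypothetical stand-ins used only to show the two dispatch styles.
class Rnn { }
class Lstm : Rnn { }
class PeepholeLstm : Lstm { }   // an imagined subclass of Lstm

class Program
{
    static string DescribeWithGetType(Rnn layer)
    {
        // Exact-type comparison: a subclass of Lstm does NOT match.
        if (layer.GetType() == typeof(Lstm))
            return "LSTM (exact type)";
        return "other";
    }

    static string DescribeWithPattern(Rnn layer)
    {
        // Type pattern: any Lstm, including subclasses, matches.
        if (layer is Lstm)
            return "LSTM (or subclass)";
        return "other";
    }

    static void Main()
    {
        Rnn layer = new PeepholeLstm();
        Console.WriteLine(DescribeWithGetType(layer));  // other
        Console.WriteLine(DescribeWithPattern(layer));  // LSTM (or subclass)
    }
}
```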
@@ -154,12 +158,18 @@ private RNN _recreate_layer_from_config(RNN layer, bool go_backwards = false)
            {
                config.GoBackwards = !config.GoBackwards;
            }
-            var actualType = layer.GetType();
-            if (actualType == typeof(LSTM))
+
+            if (layer is LSTM)
            {
                var arg = config as LSTMArgs;
                return new LSTM(arg);
            }
+            else if (layer is SimpleRNN)
+            {
+                var arg = config as SimpleRNNArgs;
+                return new SimpleRNN(arg);
+            }
+            // TODO(Wanglongzhi2001), add GRU if case.
            else
            {
                return new RNN(cell, config);
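For context, `_recreate_layer_from_config` appears to take the wrapped layer's argument object and, when `go_backwards` is requested, flip the traversal direction before constructing the backward copy, so the new layer shares no state with the original. A rough standalone sketch of that idea, with a hypothetical `RnnConfig` record standing in for `RNNArgs` (none of these names are the repository's types):

```csharp
// Hypothetical, simplified stand-ins; RnnConfig is not the repository's RNNArgs.
record RnnConfig
{
    public int Units { get; init; }
    public bool GoBackwards { get; init; }
}

class RnnStub
{
    public RnnConfig Config { get; }
    public RnnStub(RnnConfig config) => Config = config;
}

static class Recreate
{
    // Build a fresh layer from an existing layer's config,
    // optionally reversing the traversal direction.
    public static RnnStub FromConfig(RnnStub layer, bool goBackwards = false)
    {
        var config = layer.Config;
        if (goBackwards)
            config = config with { GoBackwards = !config.GoBackwards };
        return new RnnStub(config);   // a brand-new layer, no shared state
    }
}

// Usage sketch:
//   var backward = Recreate.FromConfig(forward, goBackwards: true);
```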
@@ -183,7 +193,6 @@ public override void build(KerasShapesWrapper input_shape)
        protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
        {
            // `Bidirectional.call` implements the same API as the wrapped `RNN`.
-
            Tensors forward_inputs;
            Tensors backward_inputs;
            Tensors forward_state;