Skip to content

Commit adeed05

Browse files
authored
Merge pull request #1162 from Wanglongzhi2001/master
fix: remove the reflection in the implementation of Bidirectional
2 parents 48403a5 + 6d3f134 commit adeed05

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs

+20-11
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,17 @@ namespace Tensorflow.Keras.Layers
1313
/// </summary>
1414
public class Bidirectional: Wrapper
1515
{
16-
BidirectionalArgs _args;
17-
RNN _forward_layer;
18-
RNN _backward_layer;
19-
RNN _layer;
20-
bool _support_masking = true;
2116
int _num_constants = 0;
17+
bool _support_masking = true;
2218
bool _return_state;
2319
bool _stateful;
2420
bool _return_sequences;
25-
InputSpec _input_spec;
21+
BidirectionalArgs _args;
2622
RNNArgs _layer_args_copy;
23+
RNN _forward_layer;
24+
RNN _backward_layer;
25+
RNN _layer;
26+
InputSpec _input_spec;
2727
public Bidirectional(BidirectionalArgs args):base(args)
2828
{
2929
_args = args;
@@ -66,12 +66,16 @@ public Bidirectional(BidirectionalArgs args):base(args)
6666

6767
// Recreate the forward layer from the original layer config, so that it
6868
// will not carry over any state from the layer.
69-
var actualType = _layer.GetType();
70-
if (actualType == typeof(LSTM))
69+
if (_layer is LSTM)
7170
{
7271
var arg = _layer_args_copy as LSTMArgs;
7372
_forward_layer = new LSTM(arg);
7473
}
74+
else if(_layer is SimpleRNN)
75+
{
76+
var arg = _layer_args_copy as SimpleRNNArgs;
77+
_forward_layer = new SimpleRNN(arg);
78+
}
7579
// TODO(Wanglongzhi2001), add GRU if case.
7680
else
7781
{
@@ -154,12 +158,18 @@ private RNN _recreate_layer_from_config(RNN layer, bool go_backwards = false)
154158
{
155159
config.GoBackwards = !config.GoBackwards;
156160
}
157-
var actualType = layer.GetType();
158-
if (actualType == typeof(LSTM))
161+
162+
if (layer is LSTM)
159163
{
160164
var arg = config as LSTMArgs;
161165
return new LSTM(arg);
162166
}
167+
else if(layer is SimpleRNN)
168+
{
169+
var arg = config as SimpleRNNArgs;
170+
return new SimpleRNN(arg);
171+
}
172+
// TODO(Wanglongzhi2001), add GRU if case.
163173
else
164174
{
165175
return new RNN(cell, config);
@@ -183,7 +193,6 @@ public override void build(KerasShapesWrapper input_shape)
183193
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
184194
{
185195
// `Bidirectional.call` implements the same API as the wrapped `RNN`.
186-
187196
Tensors forward_inputs;
188197
Tensors backward_inputs;
189198
Tensors forward_state;

0 commit comments

Comments
 (0)