diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs b/src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs
index 6114d9c7c..0566b08ad 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs
@@ -13,17 +13,17 @@ namespace Tensorflow.Keras.Layers
     /// </summary>
     public class Bidirectional: Wrapper
     {
-        BidirectionalArgs _args;
-        RNN _forward_layer;
-        RNN _backward_layer;
-        RNN _layer;
-        bool _support_masking = true;
         int _num_constants = 0;
+        bool _support_masking = true;
         bool _return_state;
         bool _stateful;
         bool _return_sequences;
-        InputSpec _input_spec;
+        BidirectionalArgs _args;
         RNNArgs _layer_args_copy;
+        RNN _forward_layer;
+        RNN _backward_layer;
+        RNN _layer;
+        InputSpec _input_spec;
         public Bidirectional(BidirectionalArgs args):base(args)
         {
             _args = args;
@@ -66,12 +66,16 @@ public Bidirectional(BidirectionalArgs args):base(args)

             // Recreate the forward layer from the original layer config, so that it
             // will not carry over any state from the layer.
-            var actualType = _layer.GetType();
-            if (actualType == typeof(LSTM))
+            if (_layer is LSTM)
             {
                 var arg = _layer_args_copy as LSTMArgs;
                 _forward_layer = new LSTM(arg);
             }
+            else if (_layer is SimpleRNN)
+            {
+                var arg = _layer_args_copy as SimpleRNNArgs;
+                _forward_layer = new SimpleRNN(arg);
+            }
             // TODO(Wanglongzhi2001), add GRU if case.
             else
             {
@@ -154,12 +158,18 @@ private RNN _recreate_layer_from_config(RNN layer, bool go_backwards = false)
             {
                 config.GoBackwards = !config.GoBackwards;
             }
-            var actualType = layer.GetType();
-            if (actualType == typeof(LSTM))
+
+            if (layer is LSTM)
             {
                 var arg = config as LSTMArgs;
                 return new LSTM(arg);
             }
+            else if (layer is SimpleRNN)
+            {
+                var arg = config as SimpleRNNArgs;
+                return new SimpleRNN(arg);
+            }
+            // TODO(Wanglongzhi2001), add GRU if case.
             else
             {
                 return new RNN(cell, config);
@@ -183,7 +193,6 @@ public override void build(KerasShapesWrapper input_shape)
         protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
        {
             // `Bidirectional.call` implements the same API as the wrapped `RNN`.
-
             Tensors forward_inputs;
             Tensors backward_inputs;
             Tensors forward_state;
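
Review note: replacing `var actualType = _layer.GetType(); if (actualType == typeof(LSTM))` with `if (_layer is LSTM)` is not purely cosmetic. The `is` operator also matches subclasses of the named type, while the `GetType()` comparison only matched the exact runtime type. A self-contained sketch with hypothetical stand-in types (not from this repository) illustrating the difference:

    using System;

    class Animal { }
    class Dog : Animal { }
    class Puppy : Dog { }

    class TypeCheckDemo
    {
        static void Main()
        {
            Animal a = new Puppy();
            // `is` succeeds for the named type and any of its subclasses...
            Console.WriteLine(a is Dog);                    // True
            // ...while an exact GetType() comparison does not.
            Console.WriteLine(a.GetType() == typeof(Dog));  // False (runtime type is Puppy)
        }
    }

For the dispatch in this diff the two forms behave identically as long as LSTM and SimpleRNN have no subclasses, but `is` is the more idiomatic C# and degrades more gracefully if a subclass is introduced later.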
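
The TODO left at both dispatch sites suggests the eventual GRU branch would mirror the SimpleRNN one added here. A minimal sketch of that fragment, assuming Tensorflow.Keras exposes a GRU layer with a matching GRUArgs class (neither is confirmed by this diff); it would slot into the if/else chain of _recreate_layer_from_config shown above:

    // Hypothetical GRU branch; GRU and GRUArgs are assumed to mirror the
    // LSTM/LSTMArgs and SimpleRNN/SimpleRNNArgs pairs used in this diff.
    else if (layer is GRU)
    {
        var arg = config as GRUArgs;
        return new GRU(arg);
    }

The same shape would apply in the constructor's forward-layer recreation, with `_layer_args_copy` in place of `config` and an assignment to `_forward_layer` in place of the return.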