fix: remove the reflection in the implementation of Bidirectional #1162

Merged · 1 commit · Jul 29, 2023
31 changes: 20 additions & 11 deletions src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs
@@ -13,17 +13,17 @@ namespace Tensorflow.Keras.Layers
     /// </summary>
     public class Bidirectional: Wrapper
     {
-        BidirectionalArgs _args;
-        RNN _forward_layer;
-        RNN _backward_layer;
-        RNN _layer;
-        bool _support_masking = true;
         int _num_constants = 0;
+        bool _support_masking = true;
         bool _return_state;
         bool _stateful;
         bool _return_sequences;
-        InputSpec _input_spec;
+        BidirectionalArgs _args;
         RNNArgs _layer_args_copy;
+        RNN _forward_layer;
+        RNN _backward_layer;
+        RNN _layer;
+        InputSpec _input_spec;
         public Bidirectional(BidirectionalArgs args):base(args)
         {
             _args = args;
@@ -66,12 +66,16 @@ public Bidirectional(BidirectionalArgs args):base(args)

             // Recreate the forward layer from the original layer config, so that it
             // will not carry over any state from the layer.
-            var actualType = _layer.GetType();
-            if (actualType == typeof(LSTM))
+            if (_layer is LSTM)
             {
                 var arg = _layer_args_copy as LSTMArgs;
                 _forward_layer = new LSTM(arg);
             }
+            else if(_layer is SimpleRNN)
+            {
+                var arg = _layer_args_copy as SimpleRNNArgs;
+                _forward_layer = new SimpleRNN(arg);
+            }
+            // TODO(Wanglongzhi2001), add GRU if case.
             else
             {
@@ -154,12 +158,18 @@ private RNN _recreate_layer_from_config(RNN layer, bool go_backwards = false)
             {
                 config.GoBackwards = !config.GoBackwards;
             }
-            var actualType = layer.GetType();
-            if (actualType == typeof(LSTM))
+
+            if (layer is LSTM)
             {
                 var arg = config as LSTMArgs;
                 return new LSTM(arg);
             }
+            else if(layer is SimpleRNN)
+            {
+                var arg = config as SimpleRNNArgs;
+                return new SimpleRNN(arg);
+            }
+            // TODO(Wanglongzhi2001), add GRU if case.
             else
             {
                 return new RNN(cell, config);
@@ -183,7 +193,6 @@ public override void build(KerasShapesWrapper input_shape)
         protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
        {
             // `Bidirectional.call` implements the same API as the wrapped `RNN`.
-
             Tensors forward_inputs;
             Tensors backward_inputs;
             Tensors forward_state;
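For readers comparing the two approaches: `layer.GetType() == typeof(LSTM)` is an exact-type comparison obtained through reflection, so it is false for any instance of a subclass of `LSTM`, while the `is` operator performs an ordinary type test that also matches derived types and needs no explicit reflection call. A minimal standalone sketch of the difference, not taken from this PR (the `PeepholeLSTM` subclass is hypothetical, and the toy classes only mirror the shape of the real TensorFlow.NET types):

using System;

class RNN { }
class LSTM : RNN { }
class PeepholeLSTM : LSTM { }  // hypothetical subclass, for illustration only

static class TypeCheckDemo
{
    static void Main()
    {
        RNN layer = new PeepholeLSTM();

        // Reflection-based exact-type check: false for a subclass instance.
        Console.WriteLine(layer.GetType() == typeof(LSTM));  // False

        // `is` type test: true for LSTM and any type derived from it.
        Console.WriteLine(layer is LSTM);                    // True

        // Pattern matching can also bind the cast in one step, which would
        // make the separate `as` cast used in the PR unnecessary.
        if (layer is LSTM lstm)
            Console.WriteLine($"Matched {lstm.GetType().Name}");
    }
}

One consequence worth noting: the `is` checks are not a purely mechanical rewrite of the removed reflection. A subclass of `LSTM` now takes the `LSTM` branch instead of falling through to the `else`, which is presumably the intended behavior when recreating a wrapped layer from its config.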