Skip to content

Commit 2adfcd2

Browse files
committed
Fix Sequential model.summary missing layers. #960
1 parent 196401b commit 2adfcd2

File tree

4 files changed

+9
-2
lines changed

4 files changed

+9
-2
lines changed

src/TensorFlowNET.Keras/Engine/Layer.FlattenLayers.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public IEnumerable<ILayer> _flatten_layers(bool recursive = true, bool include_s
1010
yield return this;
1111

1212
var seen_object_ids = new List<int>();
13-
var deque = new Queue<ILayer>(_layers);
13+
var deque = new Queue<ILayer>(_self_tracked_trackables);
1414
while (!deque.empty())
1515
{
1616
var layer_or_container = deque.Dequeue();

src/TensorFlowNET.Keras/Engine/Layer.Layers.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace Tensorflow.Keras.Engine
66
public partial class Layer
77
{
88
protected List<ILayer> _layers = new List<ILayer>();
9-
public List<ILayer> Layers => _layers;
9+
public virtual List<ILayer> Layers => _layers;
1010

1111
protected void StackLayers(params ILayer[] layers)
1212
{

src/TensorFlowNET.Keras/Engine/Model.cs

+4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System.Collections.Generic;
2+
using System.Linq;
23
using Tensorflow.Keras.ArgsDefinition;
34
using Tensorflow.Keras.Engine.DataAdapters;
45
using Tensorflow.Keras.Losses;
@@ -70,6 +71,9 @@ void _init_batch_counters()
7071
aggregation: VariableAggregation.OnlyFirstReplica);
7172
}
7273

74+
public override List<ILayer> Layers
75+
=> _flatten_layers(recursive: false, include_self: false).ToList();
76+
7377
public override List<IVariableV1> trainable_variables
7478
{
7579
get

src/TensorFlowNET.Keras/Engine/Sequential.cs

+3
Original file line numberDiff line numberDiff line change
@@ -202,5 +202,8 @@ void track_nodes_created_by_last_call(ILayer layer, List<INode> created_nodes)
202202
created_nodes.add(prev_layer.OutboundNodes.Last());
203203
}
204204
}
205+
206+
public override List<ILayer> Layers
207+
=> base.Layers.Where(x => x is not InputLayer).ToList();
205208
}
206209
}

0 commit comments

Comments
 (0)