Skip to content

Commit 9deda78

Browse files
committed
Revise customized json converters.
1 parent ecb4d23 commit 9deda78

File tree

6 files changed

+50
-10
lines changed

6 files changed

+50
-10
lines changed

src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,16 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
3737

3838
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
3939
{
40-
var axis = serializer.Deserialize(reader, typeof(int[]));
40+
int[]? axis;
41+
if(reader.ValueType == typeof(long))
42+
{
43+
axis = new int[1];
44+
axis[0] = (int)serializer.Deserialize(reader, typeof(int));
45+
}
46+
else
47+
{
48+
axis = serializer.Deserialize(reader, typeof(int[])) as int[];
49+
}
4150
if (axis is null)
4251
{
4352
throw new ValueError("Cannot deserialize 'null' to `Axis`.");

src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,26 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
5151

5252
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
5353
{
54-
var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
55-
if(dims is null)
54+
long?[] dims;
55+
try
56+
{
57+
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
58+
}
59+
catch (JsonSerializationException ex)
60+
{
61+
if (reader.Value.Equals("class_name"))
62+
{
63+
reader.Read();
64+
reader.Read();
65+
reader.Read();
66+
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
67+
}
68+
else
69+
{
70+
throw ex;
71+
}
72+
}
73+
if (dims is null)
5674
{
5775
throw new ValueError("Cannot deserialize 'null' to `Shape`.");
5876
}

src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using static Tensorflow.Binding;
1212
using System.Runtime.CompilerServices;
1313
using Tensorflow.Variables;
14+
using Tensorflow.Functions;
1415

1516
namespace Tensorflow
1617
{

src/TensorFlowNET.Keras/Engine/Functional.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,14 @@ protected void _init_graph_network(Tensors inputs, Tensors outputs)
7575
this.inputs = inputs;
7676
this.outputs = outputs;
7777
built = true;
78-
_buildInputShape = inputs.shape;
78+
if(inputs.Length > 0)
79+
{
80+
_buildInputShape = inputs.shape;
81+
}
82+
else
83+
{
84+
_buildInputShape = new Saving.TensorShapeConfig();
85+
}
7986

8087
if (outputs.Any(x => x.KerasHistory == null))
8188
base_layer_utils.create_keras_history(outputs);

src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ public void load_layers(bool compile = true)
7272
{
7373
try
7474
{
75+
if (node_metadata.Identifier.Equals("_tf_keras_metric"))
76+
{
77+
continue;
78+
}
7579
loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier,
7680
node_metadata.Metadata);
7781
}
@@ -324,7 +328,9 @@ private void _unblock_model_reconstruction(int layer_id, Layer layer)
324328
Trackable obj;
325329
if(identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER)
326330
{
327-
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
331+
// TODO(Rinne): implement it.
332+
return (null, null);
333+
//throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
328334
}
329335
else
330336
{
@@ -343,7 +349,7 @@ private void _unblock_model_reconstruction(int layer_id, Layer layer)
343349

344350
private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata)
345351
{
346-
// TODO: implement it.
352+
// TODO(Rinne): implement it.
347353
throw new NotImplementedException();
348354
}
349355

@@ -367,15 +373,14 @@ Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_
367373
}
368374
else if(identifier == Keras.Saving.SavedModel.Constants.SEQUENTIAL_IDENTIFIER)
369375
{
370-
model = model = new Sequential(new SequentialArgs
376+
model = new Sequential(new SequentialArgs
371377
{
372378
Name = class_name
373379
});
374380
}
375381
else
376382
{
377-
// TODO: implement it.
378-
throw new NotImplementedException("Not implemented");
383+
model = new Functional(new Tensors(), new Tensors(), config["name"].ToObject<string>());
379384
}
380385

381386
// Record this model and its layers. This will later be used to reconstruct

test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public class SequentialModelLoad
2121
[TestMethod]
2222
public void SimpleModelFromSequential()
2323
{
24-
var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/tf.net.simple.sequential");
24+
var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/model.pb");
2525
Debug.Assert(model is Model);
2626
var m = model as Model;
2727

0 commit comments

Comments
 (0)