Skip to content

Commit 4d0a64f

Browse files
committed
Add support for loading models from python.
1 parent 9deda78 commit 4d0a64f

File tree

23 files changed

+273
-74
lines changed

23 files changed

+273
-74
lines changed

src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,17 @@ public int GetVariableNumDims(string name)
6262
return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name);
6363
}
6464

65-
public unsafe Tensor GetTensor(string name)
65+
public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
6666
{
6767
Status status = new Status();
6868
var tensor = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle);
6969
status.Check(true);
7070
var shape = GetVariableShape(name);
71-
var dtype = GetVariableDataType(name);
72-
return new Tensor(c_api.TF_TensorData(tensor), shape, dtype);
71+
if(dtype == TF_DataType.DtInvalid)
72+
{
73+
dtype = GetVariableDataType(name);
74+
}
75+
return new Tensor(tensor);
7376
}
7477

7578
private void ReadAllShapeAndType()

src/TensorFlowNET.Core/Checkpoint/checkpoint.cs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ public LoadStatus restore(string? save_path, CheckpointOptions? options = null)
227227
{
228228
dtype_map = reader.VariableToDataTypeMap;
229229
}
230-
Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY);
230+
Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING);
231231

232232
Dictionary<Tensor, string> file_prefix_feed_dict;
233233
Tensor file_prefix_tensor;
@@ -249,7 +249,14 @@ public LoadStatus restore(string? save_path, CheckpointOptions? options = null)
249249
file_prefix_feed_dict = null;
250250
}
251251
TrackableObjectGraph object_graph_proto = new();
252-
object_graph_proto.MergeFrom(object_graph_string.BufferToArray());
252+
if(object_graph_string.ndim > 0)
253+
{
254+
object_graph_proto.MergeFrom(object_graph_string.BufferToArray());
255+
}
256+
else
257+
{
258+
object_graph_proto.MergeFrom(object_graph_string.StringBytes()[0]);
259+
}
253260
CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator(
254261
object_graph_proto: object_graph_proto,
255262
save_path: save_path,

src/TensorFlowNET.Core/Functions/ConcreteFunction.cs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ namespace Tensorflow.Functions
1313
/// </summary>
1414
public class ConcreteFunction: Trackable
1515
{
16-
FuncGraph func_graph;
17-
ForwardBackwardCall forward_backward;
16+
internal FuncGraph func_graph;
17+
internal ForwardBackwardCall forward_backward;
1818
public Tensor[] Inputs => func_graph.Inputs;
1919
public Tensor[] CapturedInputs => func_graph.external_captures;
2020

@@ -23,6 +23,8 @@ public class ConcreteFunction: Trackable
2323
public Tensor[] Outputs;
2424
public Type ReturnType;
2525
public TensorSpec[] OutputStructure;
26+
public IEnumerable<string> ArgKeywords { get; set; }
27+
public long NumPositionArgs { get; set; }
2628

2729
public ConcreteFunction(string name)
2830
{
@@ -163,6 +165,15 @@ public Tensors CallFlat(Tensor[] args, Tensor[] captured_inputs)
163165
return flat_outputs;
164166
}
165167

168+
public void AddTograph(Graph? g = null)
169+
{
170+
if(!tf.Context.executing_eagerly() && g is null)
171+
{
172+
g = ops.get_default_graph();
173+
}
174+
// TODO(Rinne); complete it with `_delayed_rewrite_functions`.
175+
}
176+
166177
ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)
167178
{
168179
var functions = new FirstOrderTapeGradientFunctions(func_graph, false);
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using Newtonsoft.Json.Linq;
2+
using Newtonsoft.Json;
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Text;
6+
7+
namespace Tensorflow.Keras.Common
8+
{
9+
public class CustomizedDTypeJsonConverter : JsonConverter
10+
{
11+
public override bool CanConvert(Type objectType)
12+
{
13+
return objectType == typeof(TF_DataType);
14+
}
15+
16+
public override bool CanRead => true;
17+
18+
public override bool CanWrite => true;
19+
20+
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
21+
{
22+
var token = JToken.FromObject(value);
23+
token.WriteTo(writer);
24+
}
25+
26+
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
27+
{
28+
if (reader.ValueType == typeof(string))
29+
{
30+
var str = (string)serializer.Deserialize(reader, typeof(string));
31+
return dtypes.tf_dtype_from_name(str);
32+
}
33+
else
34+
{
35+
return (TF_DataType)serializer.Deserialize(reader, typeof(TF_DataType));
36+
}
37+
}
38+
}
39+
}

src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,27 +46,54 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
4646
{
4747
throw new ValueError("Cannot deserialize 'null' to `Shape`.");
4848
}
49-
if(values.Length != 3)
49+
if(values.Length == 1)
50+
{
51+
var array = values[0] as JArray;
52+
if(array is null)
53+
{
54+
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`.");
55+
}
56+
values = array.ToObject<object[]>();
57+
}
58+
if (values.Length < 3)
5059
{
5160
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`.");
5261
}
5362
if (values[0] is not string)
5463
{
5564
throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`");
5665
}
57-
if (values[1] is not int)
66+
int nodeIndex;
67+
int tensorIndex;
68+
if (values[1] is long)
69+
{
70+
nodeIndex = (int)(long)values[1];
71+
}
72+
else if (values[1] is int)
73+
{
74+
nodeIndex = (int)values[1];
75+
}
76+
else
5877
{
5978
throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`");
6079
}
61-
if (values[2] is not int)
80+
if (values[2] is long)
81+
{
82+
tensorIndex = (int)(long)values[2];
83+
}
84+
else if (values[1] is int)
85+
{
86+
tensorIndex = (int)values[2];
87+
}
88+
else
6289
{
6390
throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`");
6491
}
6592
return new NodeConfig()
6693
{
6794
Name = values[0] as string,
68-
NodeIndex = (int)values[1],
69-
TensorIndex = (int)values[2]
95+
NodeIndex = nodeIndex,
96+
TensorIndex = tensorIndex
7097
};
7198
}
7299
}

src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
using Newtonsoft.Json;
2+
using Newtonsoft.Json.Linq;
23
using System;
34
using System.Collections.Generic;
45
using System.Text;
6+
using Tensorflow.Keras.ArgsDefinition;
57
using Tensorflow.Keras.Engine;
8+
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;
69

710
namespace Tensorflow.Keras.Saving
811
{

src/TensorFlowNET.Core/Tensors/TF_DataType.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1-
namespace Tensorflow
1+
using Newtonsoft.Json;
2+
using Tensorflow.Keras.Common;
3+
4+
namespace Tensorflow
25
{
36
/// <summary>
47
/// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor.
58
/// The enum values here are identical to corresponding values in types.proto.
69
/// </summary>
10+
[JsonConverter(typeof(CustomizedDTypeJsonConverter))]
711
public enum TF_DataType
812
{
913
DtInvalid = 0,

src/TensorFlowNET.Core/Tensors/dtypes.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,10 @@ public static TF_DataType tf_dtype_from_name(string name)
159159
"uint32" => TF_DataType.TF_UINT32,
160160
"int64" => TF_DataType.TF_INT64,
161161
"uint64" => TF_DataType.TF_UINT64,
162+
"float16" => TF_DataType.TF_BFLOAT16,
163+
"float32" => TF_DataType.TF_FLOAT,
162164
"single" => TF_DataType.TF_FLOAT,
165+
"float64" => TF_DataType.TF_DOUBLE,
163166
"double" => TF_DataType.TF_DOUBLE,
164167
"complex" => TF_DataType.TF_COMPLEX128,
165168
"string" => TF_DataType.TF_STRING,
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Functions;
5+
6+
namespace Tensorflow.Training.Saving.SavedModel
7+
{
8+
/// <summary>
9+
/// A class wraps a concrete function to handle different distributed contexts.
10+
/// </summary>
11+
internal class WrapperFunction: ConcreteFunction
12+
{
13+
public WrapperFunction(ConcreteFunction concrete_function): base(concrete_function.func_graph)
14+
{
15+
this.forward_backward = concrete_function.forward_backward;
16+
this.Outputs = concrete_function.Outputs;
17+
this.ReturnType = concrete_function.ReturnType;
18+
this.OutputStructure = concrete_function.OutputStructure;
19+
this.ArgKeywords = concrete_function.ArgKeywords;
20+
}
21+
}
22+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using Tensorflow.Functions;
6+
using Tensorflow.Util;
7+
8+
namespace Tensorflow.Training.Saving.SavedModel
9+
{
10+
public static class function_deserialization
11+
{
12+
public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function,
13+
IDictionary<string, ConcreteFunction> concrete_functions)
14+
{
15+
var concrete_function = concrete_functions[saved_bare_concrete_function.ConcreteFunctionName];
16+
concrete_function.ArgKeywords = saved_bare_concrete_function.ArgumentKeywords.ToList();
17+
concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments;
18+
19+
var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec);
20+
concrete_function.AddTograph();
21+
return concrete_function;
22+
}
23+
24+
private static FunctionSpec _deserialize_function_spec_as_nonmethod(FunctionSpec function_spec_proto)
25+
{
26+
// TODO(Rinne); revise the implementation.
27+
return new FunctionSpec()
28+
{
29+
Fullargspec = function_spec_proto.Fullargspec,
30+
IsMethod = function_spec_proto.IsMethod,
31+
InputSignature = function_spec_proto.InputSignature,
32+
JitCompile = function_spec_proto.JitCompile
33+
};
34+
}
35+
}
36+
}

0 commit comments

Comments
 (0)