Skip to content

Commit 8550dcc

Browse files
committed
Add missing trackable class but not implemented.
1 parent 095bf33 commit 8550dcc

File tree

8 files changed

+120
-4
lines changed

8 files changed

+120
-4
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using Tensorflow.Util;
18+
19+
namespace Tensorflow;
20+
21+
public sealed class SafeOperationHandle : SafeTensorflowHandle
22+
{
23+
private SafeOperationHandle()
24+
{
25+
}
26+
27+
public SafeOperationHandle(IntPtr handle)
28+
: base(handle)
29+
{
30+
}
31+
32+
protected override bool ReleaseHandle()
33+
{
34+
var status = new Status();
35+
// c_api.TF_CloseSession(handle, status);
36+
c_api.TF_DeleteSession(handle, status);
37+
SetHandle(IntPtr.Zero);
38+
return true;
39+
}
40+
}

src/TensorFlowNET.Core/Tensors/Tensors.cs

+12
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,18 @@ public void Insert(int index, Tensor tensor)
6565
IEnumerator IEnumerable.GetEnumerator()
6666
=> GetEnumerator();
6767

68+
public string[] StringData()
69+
{
70+
EnsureSingleTensor(this, "nnumpy");
71+
return this[0].StringData();
72+
}
73+
74+
public string StringData(int index)
75+
{
76+
EnsureSingleTensor(this, "nnumpy");
77+
return this[0].StringData(index);
78+
}
79+
6880
public NDArray numpy()
6981
{
7082
EnsureSingleTensor(this, "nnumpy");
+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using Tensorflow.Train;
2+
3+
namespace Tensorflow.Trackables;
4+
5+
public class Asset : Trackable
6+
{
7+
public static (Trackable, Action<object, object, object>) deserialize_from_proto()
8+
{
9+
return (null, null);
10+
}
11+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
using Tensorflow.Train;
2+
3+
namespace Tensorflow.Trackables;
4+
5+
public class CapturableResource : Trackable
6+
{
7+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using System.Runtime.CompilerServices;
2+
using Tensorflow.Train;
3+
4+
namespace Tensorflow.Trackables;
5+
6+
public class RestoredResource : TrackableResource
7+
{
8+
public static (Trackable, Action<object, object, object>) deserialize_from_proto()
9+
{
10+
return (null, null);
11+
}
12+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using Tensorflow.Train;
2+
3+
namespace Tensorflow.Trackables;
4+
5+
public class TrackableConstant : Trackable
6+
{
7+
public static (Trackable, Action<object, object, object>) deserialize_from_proto()
8+
{
9+
return (null, null);
10+
}
11+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
namespace Tensorflow.Trackables;
2+
3+
public class TrackableResource : CapturableResource
4+
{
5+
}

src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs

+22-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
using Tensorflow.Variables;
1414
using Tensorflow.Functions;
1515
using Tensorflow.Training.Saving.SavedModel;
16+
using Tensorflow.Trackables;
1617

1718
namespace Tensorflow
1819
{
@@ -51,9 +52,13 @@ public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto,
5152
_node_filters = filters;
5253
_node_path_to_id = _convert_node_paths_to_ints();
5354
_loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>();
54-
foreach(var filter in filters)
55+
56+
if (filters != null)
5557
{
56-
_loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value;
58+
foreach (var filter in filters)
59+
{
60+
_loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value;
61+
}
5762
}
5863

5964
_filtered_nodes = _retrieve_all_filtered_nodes();
@@ -535,7 +540,13 @@ private void _add_object_graph_edges(SavedObject proto, int node_id)
535540
dependencies[item.Key] = nodes[item.Value];
536541
}
537542

538-
return _recreate_default(proto, node_id, dependencies);
543+
return proto.KindCase switch
544+
{
545+
SavedObject.KindOneofCase.Resource => RestoredResource.deserialize_from_proto(),
546+
SavedObject.KindOneofCase.Asset => Asset.deserialize_from_proto(),
547+
SavedObject.KindOneofCase.Constant => TrackableConstant.deserialize_from_proto(),
548+
_ => _recreate_default(proto, node_id, dependencies)
549+
};
539550
}
540551

541552
/// <summary>
@@ -549,7 +560,7 @@ private void _add_object_graph_edges(SavedObject proto, int node_id)
549560
return proto.KindCase switch
550561
{
551562
SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id),
552-
SavedObject.KindOneofCase.Function => throw new NotImplementedException(),
563+
SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null),
553564
SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(),
554565
SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable),
555566
SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException()
@@ -609,6 +620,13 @@ private void _add_object_graph_edges(SavedObject proto, int node_id)
609620
}
610621
}
611622

623+
private (ConcreteFunction, Action<object, object, object>) _recreate_function(SavedFunction proto,
624+
Dictionary<Maybe<string, int>, Trackable> dependencies)
625+
{
626+
throw new NotImplementedException();
627+
//var fn = function_deserialization.setup_bare_concrete_function(proto, )
628+
}
629+
612630
private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto,
613631
Dictionary<Maybe<string, int>, Trackable> dependencies)
614632
{

0 commit comments

Comments
 (0)