diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
index 8ae2dae8f..9793798d2 100644
--- a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
+++ b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
@@ -149,4 +149,22 @@ public static void add_checkpoint_values_check(TrackableObjectGraph object_graph
// object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i);
// }
}
+
+ ///
+ /// Traverse the object graph and list all accessible objects.
+ ///
+ ///
+ public static IList list_objects(ObjectGraphView graph_view)
+ {
+ return objects_ids_and_slot_variables_and_paths(graph_view).Item1;
+ }
+
+ internal static IEnumerable _objects_with_attributes(IEnumerable full_list)
+ {
+ return full_list.TakeWhile(x =>
+ {
+ var saveables = x.gather_saveables_for_checkpoint();
+ return saveables is not null && saveables.Count > 0;
+ });
+ }
}
diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs
new file mode 100644
index 000000000..0cc8e5fbd
--- /dev/null
+++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs
@@ -0,0 +1,100 @@
+using Tensorflow.Util;
+
+namespace Tensorflow.Checkpoint
+{
+ sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle
+ {
+ public SafeCheckpointReaderHandle(): base()
+ {
+
+ }
+ public SafeCheckpointReaderHandle(IntPtr handle): base(handle)
+ {
+
+ }
+
+ protected override bool ReleaseHandle()
+ {
+ c_api.TF_DeleteCheckpointReader(handle);
+ SetHandle(IntPtr.Zero);
+ return true;
+ }
+ }
+ public class CheckpointReader
+ {
+ private SafeCheckpointReaderHandle _handle;
+ public Dictionary VariableToDataTypeMap { get; set; }
+ public Dictionary VariableToShapeMap { get; set; }
+
+ public CheckpointReader(string filename)
+ {
+ Status status = new Status();
+ _handle = c_api.TF_NewCheckpointReader(filename, status.Handle);
+ status.Check(true);
+ ReadAllShapeAndType();
+ }
+
+ public int HasTensor(string name)
+ {
+ return c_api.TF_CheckpointReaderHasTensor(_handle, name);
+ }
+
+ ///
+ /// Get the variable name.
+ ///
+ ///
+ ///
+ public string GetVariable(int index)
+ {
+ return c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index));
+ }
+
+ public int Size()
+ {
+ return c_api.TF_CheckpointReaderSize(_handle);
+ }
+
+ public TF_DataType GetVariableDataType(string name)
+ {
+ return c_api.TF_CheckpointReaderGetVariableDataType(_handle, name);
+ }
+
+ public Shape GetVariableShape(string name)
+ {
+ int num_dims = GetVariableNumDims(name);
+ long[] dims = new long[num_dims];
+ Status status = new Status();
+ c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle);
+ status.Check(true);
+ return new Shape(dims);
+ }
+
+ public int GetVariableNumDims(string name)
+ {
+ return c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name);
+ }
+
+ public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
+ {
+ Status status = new Status();
+ var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle);
+ status.Check(true);
+ return new Tensor(tensor);
+ }
+
+ private void ReadAllShapeAndType()
+ {
+ VariableToDataTypeMap = new Dictionary();
+ VariableToShapeMap = new Dictionary();
+ int size = Size();
+ for(int i = 0; i < size; i++)
+ {
+ var name = GetVariable(i);
+ var shape = GetVariableShape(name);
+ var dtype = GetVariableDataType(name);
+ VariableToDataTypeMap[name] = dtype;
+ VariableToShapeMap[name] = shape;
+ }
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
index 3267ae126..72372e410 100644
--- a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
+++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
@@ -175,9 +175,9 @@ public static (IList, object?) generate_saveable_objects(
{
var name = factory_data.name;
var key = factory_data.checkpoint_key;
- var maybe_saveable = factory_data.factory;
+ var maybe_saveable = saveable_object_util.create_saveable_object(name, key, factory_data.factory);
- // TODO: oneflow python has a process with callable `saveable_factory`.
+ // TODO: tensorflow python has a process with callable `saveable_factory`.
List saveables = new();
if (maybe_saveable.TryGet(out var s))
{
@@ -217,7 +217,7 @@ public static (IList, object?) generate_saveable_objects(
public record class CheckpointFactoryData
(
- Maybe factory,
+ Func> factory,
string name,
string checkpoint_key
);
diff --git a/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs
new file mode 100644
index 000000000..f956e3337
--- /dev/null
+++ b/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs
@@ -0,0 +1,27 @@
+using System.Runtime.InteropServices;
+using Tensorflow.Checkpoint;
+
+namespace Tensorflow
+{
+ public unsafe partial class c_api
+ {
+ [DllImport(TensorFlowLibName)]
+ internal static extern SafeCheckpointReaderHandle TF_NewCheckpointReader(string filename, SafeStatusHandle status);
+ [DllImport(TensorFlowLibName)]
+ internal static extern void TF_DeleteCheckpointReader(IntPtr reader);
+ [DllImport(TensorFlowLibName)]
+ internal static extern int TF_CheckpointReaderHasTensor(SafeCheckpointReaderHandle reader, string name);
+ [DllImport(TensorFlowLibName)]
+ internal static extern IntPtr TF_CheckpointReaderGetVariable(SafeCheckpointReaderHandle reader, int index);
+ [DllImport(TensorFlowLibName)]
+ internal static extern int TF_CheckpointReaderSize(SafeCheckpointReaderHandle reader);
+ [DllImport(TensorFlowLibName)]
+ internal static extern TF_DataType TF_CheckpointReaderGetVariableDataType(SafeCheckpointReaderHandle reader, string name);
+ [DllImport(TensorFlowLibName)]
+ internal static extern void TF_CheckpointReaderGetVariableShape(SafeCheckpointReaderHandle reader, string name, long[] dims, int num_dims, SafeStatusHandle status);
+ [DllImport(TensorFlowLibName)]
+ internal static extern int TF_CheckpointReaderGetVariableNumDims(SafeCheckpointReaderHandle reader, string name);
+ [DllImport(TensorFlowLibName)]
+ internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(SafeCheckpointReaderHandle reader, string name, SafeStatusHandle status);
+ }
+}
diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
index 0c2862dac..1934ffd5f 100644
--- a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
+++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
@@ -6,8 +6,12 @@
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Train;
+using Tensorflow.Exceptions;
using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types;
using static Tensorflow.Binding;
+using Tensorflow.Operations;
+using Newtonsoft.Json;
+using Tensorflow.Training;
namespace Tensorflow.Checkpoint;
@@ -21,8 +25,20 @@ public class TrackableSaver
private TrackableObjectGraph _last_save_object_graph;
private Tensor? _object_graph_feed_tensor = null;
private Tensor? _file_prefix_feed_tensor = null;
+ private Tensor? _file_prefix_placeholder = null;
private Dictionary? _object_map = null;
private object? _cache = null;
+ public Tensor? FilePrefixPlaceHolder
+ {
+ get
+ {
+ return _file_prefix_placeholder;
+ }
+ set
+ {
+ _file_prefix_placeholder = value;
+ }
+ }
public TrackableSaver(ObjectGraphView graph_view)
{
_graph_view = graph_view;
@@ -192,4 +208,366 @@ public Tensor save(string file_prefix, int? checkpoint_number = null, Session? s
return save_path;
}
}
+
+ public LoadStatus restore(string? save_path, CheckpointOptions? options = null)
+ {
+ if (options is null)
+ {
+ options = new CheckpointOptions();
+ }
+ if(save_path is null)
+ {
+ return new InitializationOnlyStatus(_graph_view, ops.uid());
+ }
+
+ CheckpointReader reader = new CheckpointReader(save_path);
+ bool graph_building = tf.Context.executing_eagerly();
+ Dictionary dtype_map = null;
+ if (!graph_building)
+ {
+ dtype_map = reader.VariableToDataTypeMap;
+ }
+ Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING);
+
+ Dictionary file_prefix_feed_dict;
+ Tensor file_prefix_tensor;
+ if (graph_building)
+ {
+ if(_file_prefix_placeholder is null)
+ {
+ tf.device("/cpu:0");
+ _file_prefix_placeholder = constant_op.constant("model");
+ }
+ file_prefix_tensor = _file_prefix_placeholder;
+ file_prefix_feed_dict = new();
+ file_prefix_feed_dict[_file_prefix_placeholder] = save_path;
+ }
+ else
+ {
+ tf.device("/cpu:0");
+ file_prefix_tensor = constant_op.constant(save_path);
+ file_prefix_feed_dict = null;
+ }
+ TrackableObjectGraph object_graph_proto = new();
+ if(object_graph_string.ndim > 0)
+ {
+ object_graph_proto.MergeFrom(object_graph_string.BufferToArray());
+ }
+ else
+ {
+ object_graph_proto.MergeFrom(object_graph_string.StringBytes()[0]);
+ }
+ CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator(
+ object_graph_proto: object_graph_proto,
+ save_path: save_path,
+ save_path_tensor: file_prefix_tensor,
+ reader: reader,
+ restore_op_cache: null,
+ graph_view: _graph_view,
+ options: options,
+ saveables_cache: null
+ );
+
+ new CheckpointPosition(checkpoint, 0).restore(_graph_view.Root);
+
+ if(_graph_view.AttachedDependencies is not null)
+ {
+ foreach(var refer in _graph_view.AttachedDependencies)
+ {
+ if(refer.Name == "root")
+ {
+ continue;
+ }
+ int? proto_id = null;
+ // Find proto ID of attached dependency (if it is in the proto).
+ foreach (var proto_refer in object_graph_proto.Nodes[0].Children)
+ {
+ if(proto_refer.LocalName == refer.Name)
+ {
+ proto_id = proto_refer.NodeId;
+ break;
+ }
+ }
+
+ if (proto_id is null)
+ {
+ continue;
+ }
+
+ // Object has already been restored. This can happen when there's an
+ // indirect connection from the attached object to the root.
+ if (checkpoint.ObjectByProtoId.ContainsKey(proto_id.Value))
+ {
+ continue;
+ }
+
+ new CheckpointPosition(checkpoint, proto_id.Value).restore(refer.Refer);
+ }
+ }
+
+ return new CheckpointLoadStatus(checkpoint, file_prefix_feed_dict, _graph_view);
+ }
+}
+
+public class CheckpointRestoreCoordinator
+{
+ private CheckpointOptions _options;
+ private TrackableObjectGraph _object_graph_proto;
+ private int _restore_uid;
+ private HashSet _matched_proto_ids;
+ private Tensor _save_path_tensor;
+ private string _save_path_string;
+ private CheckpointReader _reader;
+ private Dictionary _dtype_map;
+ private Dictionary _shape_map;
+ private ObjectGraphView _graph_view;
+ private Dictionary> _slot_restorations;
+ private bool _expect_partial_attr;
+ private List _restore_ops;
+ private List _all_trackables;
+ private Dictionary _object_by_proto_id;
+ private Dictionary _restore_ops_by_name;
+ private Dictionary> _deferred_slot_restorations;
+ private Dictionary> _unused_attributes;
+
+ public CheckpointRestoreCoordinator(TrackableObjectGraph object_graph_proto, string save_path, Tensor save_path_tensor,
+ CheckpointReader reader, object? restore_op_cache, ObjectGraphView graph_view, CheckpointOptions options, object? saveables_cache)
+ {
+ // TODO(Rinne): cache.
+ _options = options;
+ _object_graph_proto = object_graph_proto;
+ _restore_uid = ops.uid();
+ _save_path_tensor = save_path_tensor;
+ _save_path_string = save_path;
+ _reader = reader;
+ if(_reader is null)
+ {
+ _reader = new CheckpointReader(save_path);
+ }
+ _dtype_map = _reader.VariableToDataTypeMap;
+ _shape_map = _reader.VariableToShapeMap;
+ _graph_view = graph_view;
+ _restore_ops = new List();
+ _restore_ops_by_name = new Dictionary();
+ _all_trackables = new List();
+ _matched_proto_ids = new HashSet();
+ _object_by_proto_id = new Dictionary();
+ _slot_restorations = new Dictionary>();
+ _deferred_slot_restorations = new Dictionary>();
+
+ _expect_partial_attr = false;
+ for(int i = 0; i < _object_graph_proto.Nodes.Count; i++)
+ {
+ var node = _object_graph_proto.Nodes[i];
+ foreach(var slot_reference in node.SlotVariables)
+ {
+ _slot_restorations.SetDefault(slot_reference.OriginalVariableNodeId, new List())
+ .Add(new SlotVariableRestoration(i, slot_reference.SlotVariableNodeId, slot_reference.SlotName));
+ }
+ }
+
+ // skip the deleter and cache.
+ }
+
+ public bool ExpectPartial
+ {
+ get
+ {
+ return _expect_partial_attr;
+ }
+ set
+ {
+ _expect_partial_attr = value;
+ }
+ }
+
+ ///
+ /// Corresponding to `all_python_objects` of tensorflow python
+ ///
+ public List AllTrackables => _all_trackables;
+ public HashSet MatchedProtoIds => _matched_proto_ids;
+ public Dictionary ObjectByProtoId => _object_by_proto_id;
+ public int RestoreUid => _restore_uid;
+ public TrackableObjectGraph ObjectGraphProto => _object_graph_proto;
+ public Dictionary> SlotRestorations => _slot_restorations;
+ public Dictionary> DeferredSlotRestorations => _deferred_slot_restorations;
+ public Dictionary RestoreOpsByName => _restore_ops_by_name;
+ public Dictionary> UnusedAttributes => _unused_attributes;
+
+ public void new_restore_ops(IEnumerable new_ops)
+ {
+ _restore_ops.AddRange(new_ops);
+ // skip the callback.
+ }
+
+ public List restore_saveables(Dictionary> tensor_saveables, List positions, object? registered_savers = null)
+ {
+ List restore_ops = new();
+ foreach(var position in positions)
+ {
+ var key = position.ObjectProto.Attributes[0].CheckpointKey;
+ throw new NotImplementedException();
+ }
+
+ Dictionary variable_dict = new();
+ foreach(var item in tensor_saveables)
+ {
+ if(item.Value.TryGet(out var variable))
+ {
+ variable_dict[item.Key] = variable;
+ }
+ else
+ {
+ throw new TypeError();
+ }
+ }
+
+ if (tensor_saveables is not null && tensor_saveables.Count > 0)
+ {
+ var flat_saveables = saveable_object_util.validate_and_slice_inputs(variable_dict);
+ var new_restore_ops = MultiDeviceSaver.from_saveables(flat_saveables).restore(_save_path_tensor, _options);
+ if (!tf.Context.executing_eagerly())
+ {
+ foreach(var item in new_restore_ops)
+ {
+ restore_ops.Add(item.Value);
+ Debug.Assert(!_restore_ops_by_name.ContainsKey(item.Key));
+ _restore_ops_by_name[item.Key] = item.Value;
+ }
+ }
+ }
+ return restore_ops;
+ }
+}
+
+public abstract class LoadStatus
+{
+ public abstract LoadStatus assert_consumed();
+ public abstract LoadStatus assert_existing_objects_matched();
+ public abstract LoadStatus assert_nontrivial_match();
+ public abstract LoadStatus run_restore_ops(Session? session = null);
+ public abstract void initialize_or_restore(Session? session = null);
+ public virtual LoadStatus expect_partial()
+ {
+ return this;
+ }
+}
+
+public class InitializationOnlyStatus: LoadStatus
+{
+ private int _restore_uid;
+ private ObjectGraphView _object_graph_view;
+ private Trackable _root;
+ public InitializationOnlyStatus(ObjectGraphView object_graph_view, int restore_uid)
+ {
+ _restore_uid = restore_uid;
+ _object_graph_view = object_graph_view;
+ _root = object_graph_view.Root;
+ }
+ public override LoadStatus assert_consumed()
+ {
+ throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");
+ }
+ public override LoadStatus assert_existing_objects_matched()
+ {
+ throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");
+ }
+ public override LoadStatus assert_nontrivial_match()
+ {
+ throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");
+ }
+ public override LoadStatus run_restore_ops(Session? session = null)
+ {
+ throw new AssertionError("No checkpoint specified, so no restore ops are available "
+ + "(save_path=None to Saver.restore).");
+ }
+ public override void initialize_or_restore(Session? session = null)
+ {
+ if (tf.Context.executing_eagerly())
+ {
+ return;
+ }
+ if(session is null)
+ {
+ session = new Session();
+ }
+ var trackable_objects = CheckPointUtils.list_objects(_object_graph_view);
+ throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
+ }
+}
+
+internal class CheckpointLoadStatus: LoadStatus
+{
+ private CheckpointRestoreCoordinator _checkpoint;
+ private Dictionary _feed_dict;
+ private ObjectGraphView _object_graph_view;
+ private Trackable _root;
+ public CheckpointLoadStatus(CheckpointRestoreCoordinator checkpoint, Dictionary feed_dict, ObjectGraphView graph_view):base()
+ {
+ _checkpoint = checkpoint;
+ _feed_dict = feed_dict;
+ _object_graph_view = graph_view;
+ _root = graph_view.Root;
+ }
+
+ public CheckpointRestoreCoordinator Checkpoint => _checkpoint;
+
+ public override LoadStatus assert_consumed()
+ {
+ throw new NotImplementedException();
+ }
+
+ public override LoadStatus assert_existing_objects_matched()
+ {
+ for(int i = 0; i < _checkpoint.ObjectGraphProto.Nodes.Count; i++)
+ {
+ var node = _checkpoint.ObjectGraphProto.Nodes[i];
+ if(_checkpoint.ObjectByProtoId.TryGetValue(i, out var trackable) &&
+ trackable.UpdateUid < _checkpoint.RestoreUid)
+ {
+ throw new AssertionError($"Object {node} not assigned a value from checkpoint.");
+ }
+ }
+ foreach(var trackable_object in CheckPointUtils.list_objects(_object_graph_view))
+ {
+ if(trackable_object is TrackableDataStructure && trackable_object._trackable_children().Count == 0)
+ {
+ continue;
+ }
+ _checkpoint.AllTrackables.Add(trackable_object);
+ }
+ var unused_trackables = CheckPointUtils._objects_with_attributes(_checkpoint.AllTrackables)
+ .Except(_checkpoint.ObjectByProtoId.Values);
+ if (unused_trackables.Any())
+ {
+ var num_unused_trackables = unused_trackables.Count();
+ var num_variables_to_show = Math.Min(10, num_unused_trackables);
+ throw new AssertionError($"Found {num_unused_trackables} Python objects that were " +
+ $"not bound to checkpointed values, likely due to changes in the " +
+ $"Python program. Showing {num_variables_to_show} of " +
+ $"{num_unused_trackables} unmatched objects: " +
+ $"{{list(unused_python_objects)[:num_variables_to_show]}}");
+ }
+ return this;
+ }
+
+ public override LoadStatus assert_nontrivial_match()
+ {
+ throw new NotImplementedException();
+ }
+
+ public override LoadStatus expect_partial()
+ {
+ throw new NotImplementedException();
+ }
+
+ public override void initialize_or_restore(Session? session = null)
+ {
+ throw new NotImplementedException();
+ }
+
+ public override LoadStatus run_restore_ops(Session? session = null)
+ {
+ throw new NotImplementedException();
+ }
}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs
index 09904d684..96e6c8dd9 100644
--- a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs
+++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs
@@ -213,7 +213,7 @@ public IDictionary> restore(Tensor file_pref
// tf python has code `with ops.device(restore_device):` here.
tf.device(restore_device); // may be risky.
- var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());
+ var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());
Dictionary> restored_tensor_dict = new();
int idx = 0;
diff --git a/src/TensorFlowNET.Core/Checkpoint/restore.cs b/src/TensorFlowNET.Core/Checkpoint/restore.cs
new file mode 100644
index 000000000..b27396a79
--- /dev/null
+++ b/src/TensorFlowNET.Core/Checkpoint/restore.cs
@@ -0,0 +1,331 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Text;
+using Tensorflow.Train;
+using Tensorflow.Training;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Checkpoint;
+
+public class CheckpointPosition
+{
+ private CheckpointRestoreCoordinator _checkpoint;
+ private int _proto_id;
+ private bool _skip_restore;
+ public CheckpointPosition(CheckpointRestoreCoordinator checkpoint, int proto_id)
+ {
+ _checkpoint = checkpoint;
+ _proto_id = proto_id;
+ _skip_restore = false;
+ }
+
+ public Trackable Trackable => _checkpoint.ObjectByProtoId[_proto_id];
+ public CheckpointRestoreCoordinator Checkpoint => _checkpoint;
+ public TrackableObjectGraph.Types.TrackableObject ObjectProto => _checkpoint.ObjectGraphProto.Nodes[_proto_id];
+
+ public void restore(Trackable trackable)
+ {
+ using (ops.init_scope())
+ {
+ if (bind_project(trackable))
+ {
+ var restore_ops = _restore_descendants();
+ if(restore_ops is not null && restore_ops.Count > 0)
+ {
+ _checkpoint.new_restore_ops(restore_ops);
+ }
+ }
+ }
+ }
+
+ ///
+ /// Set a checkpoint<->object correspondence.
+ ///
+ ///
+ ///
+ public bool bind_project(Trackable trackable)
+ {
+ _checkpoint.AllTrackables.Add(trackable);
+ _checkpoint.MatchedProtoIds.Add(_proto_id);
+ if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment))
+ {
+ // skip the `logging.warning`.
+ return false;
+ }
+ else
+ {
+ _checkpoint.ObjectByProtoId[_proto_id] = trackable;
+ return true;
+ }
+ }
+
+ public (List, Dictionary>, List, object?) gather_ops_or_named_saveables()
+ {
+ // skip the registered_saver
+
+ if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0)
+ {
+ return (new List(), new Dictionary>(),
+ new List(), null);
+ }
+
+ var saveable_factories = saveable_object_util.saveable_objects_from_trackable(this.Trackable);
+
+ List existing_restore_ops;
+ List positions = new();
+ Dictionary> named_saveables;
+ if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME)
+ {
+ (existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories);
+ }
+ else if(saveable_factories.Count > 0)
+ {
+ (existing_restore_ops, named_saveables) = _create_saveables_by_attribute_name(saveable_factories);
+ }
+ else
+ {
+ throw new NotImplementedException();
+ }
+ return (existing_restore_ops, named_saveables, positions, null);
+ }
+
+ public CheckpointPosition create_child_position(int node_id)
+ {
+ return new CheckpointPosition(_checkpoint, node_id);
+ }
+
+ public (CheckpointPosition, BaseResourceVariable) create_slot_variable_position(Optimizer optimizer_object, BaseResourceVariable variable,
+ int slot_variable_id, string slot_name)
+ {
+ //CheckpointPosition slot_variable_position = new(Checkpoint, slot_variable_id);
+
+ // TODO(Rinne): implement it.
+ return (null, null);
+ }
+
+ ///
+ /// Creates a saveable using the _serialize_to_tensor method.
+ ///
+ ///
+ private (List, Dictionary>) _create_serialize_to_tensor_saveable(
+ IDictionary>> saveable_factories)
+ {
+ string suffix = SaveableCompat.get_saveable_name(this.Trackable);
+ suffix = suffix ?? "";
+ var saveable_name = _extract_saveable_name(ObjectProto.Attributes[0].CheckpointKey) + suffix;
+
+ if (!tf.Context.executing_eagerly())
+ {
+ throw new NotImplementedException("The restore under graph mode has not been implemented. " +
+ "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
+ }
+
+ var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name);
+ // skip the cache.
+ Dictionary> dict = new();
+ dict[saveable_name] = saveable;
+ return (new List(), dict);
+ }
+
+ private (List, Dictionary>) _create_saveables_by_attribute_name(
+ IDictionary>> saveable_factories)
+ {
+ // TODO(Rinne): implement it.
+ if(ObjectProto.Attributes is null)
+ {
+ return (new List(), new Dictionary>());
+ }
+
+ List existing_restore_ops = new();
+ HashSet created_compat_names = new();
+ Dictionary> named_saveables = new();
+ foreach (var serialized_tensor in ObjectProto.Attributes)
+ {
+ Operation existing_op;
+ if (tf.Context.executing_eagerly() || !_checkpoint.RestoreOpsByName.ContainsKey(serialized_tensor.CheckpointKey))
+ {
+ existing_op = null;
+ }
+ else
+ {
+ existing_op = _checkpoint.RestoreOpsByName[serialized_tensor.CheckpointKey];
+ }
+
+ if(existing_op is not null)
+ {
+ existing_restore_ops.Add(existing_op);
+ continue;
+ }
+
+ if(created_compat_names.Any(x => serialized_tensor.Name.StartsWith(x)))
+ {
+ continue;
+ }
+
+ // TODO(Rinne): deal with cache.
+
+ var saveable = _get_saveable_from_factory(saveable_factories, serialized_tensor, created_compat_names);
+ if(saveable is null)
+ {
+ _checkpoint.UnusedAttributes.SetDefault(_proto_id, new List()).Add(serialized_tensor.Name);
+ continue;
+ }
+ named_saveables[serialized_tensor.CheckpointKey] = saveable;
+ }
+ return (existing_restore_ops, named_saveables);
+ }
+
+ private Maybe _get_saveable_from_factory(IDictionary>> saveable_factories,
+ TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet created_compat_names)
+ {
+ var expected_factory_name = serialized_tensor.Name;
+ var factory_input_name = serialized_tensor.CheckpointKey;
+
+ if (!saveable_factories.TryGetValue(expected_factory_name, out var matched_factory))
+ {
+ foreach(var item in saveable_factories)
+ {
+ var factory_name = item.Key;
+ var factory = item.Value;
+ if (expected_factory_name.StartsWith(factory_name))
+ {
+ if(matched_factory is not null)
+ {
+ throw new ValueError($"Forward compatibility load error: Unable to load " +
+ "checkpoint saved in future version of TensorFlow. " +
+ "Please update your version of TensorFlow to the " +
+ "version in which the checkpoint was saved.");
+ }
+ }
+ matched_factory = factory;
+ factory_input_name = _extract_saveable_name(serialized_tensor.CheckpointKey) + factory_name;
+ created_compat_names.Add(factory_name);
+ }
+ }
+ return matched_factory(factory_input_name);
+ }
+
+ private string _extract_saveable_name(string checkpoint_key)
+ {
+ var search_key = TrackableUtils.OBJECT_ATTRIBUTES_NAME + "/";
+ return checkpoint_key.Substring(0, checkpoint_key.IndexOf(search_key) + search_key.Length);
+ }
+
+ ///
+ /// Restore the bound Trackable and dependencies (may be deferred).
+ ///
+ private List _restore_descendants()
+ {
+ Queue<(CheckpointPosition, Trackable)> visit_queue = new();
+ visit_queue.Enqueue((this, this.Trackable));
+ List restore_ops = new();
+ Dictionary> tensor_saveables = new();
+ List positions = new();
+
+ CheckpointPosition current_position = null;
+ while (visit_queue.Count > 0)
+ {
+ current_position = visit_queue.Dequeue().Item1;
+ var (new_restore_ops, new_tensor_saveables, new_positions, new_registered_savers) = current_position._single_restore();
+ restore_ops.AddRange(new_restore_ops);
+ foreach(var item in new_tensor_saveables)
+ {
+ tensor_saveables.Add(item.Key, item.Value);
+ }
+ positions.AddRange(new_positions);
+ _queue_children_for_restoration(current_position, visit_queue);
+ _queue_slot_variables(current_position, visit_queue);
+ }
+ restore_ops.AddRange(current_position.Checkpoint.restore_saveables(tensor_saveables, positions, null));
+ return restore_ops;
+ }
+
+ private void _queue_children_for_restoration(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue)
+ {
+ var trackable = checkpoint_position.Trackable;
+ foreach(var child in checkpoint_position.ObjectProto.Children)
+ {
+ var child_position = checkpoint_position.create_child_position(child.NodeId);
+ var local_object = trackable._lookup_dependency(child.LocalName);
+ var child_proto = child_position.ObjectProto;
+ if(local_object is null)
+ {
+ if(child_proto.Children.Any() || child_proto.Attributes.Any() || child_proto.SlotVariables.Any())
+ {
+ trackable.DeferredDependencies.SetDefault(child.LocalName, new List()).Add(child_position);
+ }
+ }
+ else
+ {
+ if (child_position.bind_project(local_object))
+ {
+ visit_queue.Enqueue((child_position, local_object));
+ }
+ }
+ }
+ }
+
+ private void _queue_slot_variables(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue)
+ {
+ var trackable = checkpoint_position.Trackable;
+ var checkpoint = checkpoint_position.Checkpoint;
+ if(checkpoint.DeferredSlotRestorations.TryGetValue(checkpoint_position._proto_id, out var positions))
+ {
+ checkpoint.DeferredSlotRestorations.Remove(checkpoint_position._proto_id);
+ foreach (var deferred_slot_restoration in positions)
+ {
+ var (slot_variable_position, slot_variable) = checkpoint_position.create_slot_variable_position(
+ trackable as Optimizer, deferred_slot_restoration.OriginalVariable, deferred_slot_restoration.SlotVariableId,
+ deferred_slot_restoration.SlotName
+ );
+ if(slot_variable_position is not null)
+ {
+ visit_queue.Enqueue((slot_variable_position, slot_variable));
+ }
+ }
+ }
+ if (checkpoint.SlotRestorations.TryGetValue(checkpoint_position._proto_id, out var restorations))
+ {
+ checkpoint.SlotRestorations.Remove(checkpoint_position._proto_id);
+ foreach (var slot_restoration in restorations)
+ {
+ if(Checkpoint.ObjectByProtoId.TryGetValue(slot_restoration.OptimizerId, out var optimizer_object))
+ {
+ throw new NotImplementedException();
+ // TODO(Rinne); implement it.
+ }
+ else
+ {
+ Debug.Assert(trackable is BaseResourceVariable);
+ Checkpoint.DeferredSlotRestorations.SetDefault(slot_restoration.OptimizerId, new List())
+ .Add(new DeferredSlotVariableRestoration(trackable as BaseResourceVariable, slot_restoration.SlotVariableId, slot_restoration.SlotName));
+ }
+ }
+ }
+ }
+
+ private (List, Dictionary>, List, object?) _single_restore()
+ {
+ var trackable = this.Trackable;
+ trackable._maybe_initialize_trackable();
+ if(_checkpoint.RestoreUid > trackable.UpdateUid)
+ {
+ var (restore_ops, tensor_saveables, positions, registered_savers) = gather_ops_or_named_saveables();
+ trackable.UpdateUid = _checkpoint.RestoreUid;
+ return (restore_ops, tensor_saveables, positions, registered_savers);
+ }
+ else
+ {
+ return (new List(), new Dictionary>(),
+ new List(), null);
+ }
+ }
+}
+
+public record class DeferredSlotVariableRestoration(
+ BaseResourceVariable OriginalVariable,
+ int SlotVariableId,
+ string SlotName
+);
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Eager/execute.cs b/src/TensorFlowNET.Core/Eager/execute.cs
index cb3ea4d3c..2926f8e28 100644
--- a/src/TensorFlowNET.Core/Eager/execute.cs
+++ b/src/TensorFlowNET.Core/Eager/execute.cs
@@ -10,7 +10,7 @@
namespace Tensorflow.Eager
{
- internal class execute
+ internal static class execute
{
public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx)
{
@@ -27,5 +27,9 @@ public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] i
return tensors;
}
+ public static bool must_record_gradient()
+ {
+ return false;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
index bac9cedbf..a6720a5f3 100644
--- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
+++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
@@ -13,8 +13,8 @@ namespace Tensorflow.Functions
///
public class ConcreteFunction: Trackable
{
- FuncGraph func_graph;
- ForwardBackwardCall forward_backward;
+ internal FuncGraph func_graph;
+ internal ForwardBackwardCall forward_backward;
public Tensor[] Inputs => func_graph.Inputs;
public Tensor[] CapturedInputs => func_graph.external_captures;
@@ -23,6 +23,8 @@ public class ConcreteFunction: Trackable
public Tensor[] Outputs;
public Type ReturnType;
public TensorSpec[] OutputStructure;
+ public IEnumerable ArgKeywords { get; set; }
+ public long NumPositionArgs { get; set; }
public ConcreteFunction(string name)
{
@@ -163,6 +165,15 @@ public Tensors CallFlat(Tensor[] args, Tensor[] captured_inputs)
return flat_outputs;
}
+ public void AddTograph(Graph? g = null)
+ {
+ if(!tf.Context.executing_eagerly() && g is null)
+ {
+ g = ops.get_default_graph();
+ }
+ // TODO(Rinne); complete it with `_delayed_rewrite_functions`.
+ }
+
ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)
{
var functions = new FirstOrderTapeGradientFunctions(func_graph, false);
diff --git a/src/TensorFlowNET.Core/IO/gfile.cs b/src/TensorFlowNET.Core/IO/gfile.cs
index 5f08702da..142b8b64e 100644
--- a/src/TensorFlowNET.Core/IO/gfile.cs
+++ b/src/TensorFlowNET.Core/IO/gfile.cs
@@ -16,8 +16,10 @@ limitations under the License.
using System;
using System.Collections.Generic;
+using System.Diagnostics;
using System.IO;
using System.Linq;
+using static Tensorflow.Binding;
namespace Tensorflow.IO
{
@@ -63,5 +65,15 @@ public string[] glob(string data_dir)
dirs.AddRange(Directory.GetFiles(dir));
return dirs.ToArray();
}
+
+ public string join(params string[] paths)
+ {
+ Debug.Assert(paths.Length >= 1);
+ if (paths[0].Substring(1).Contains("://"))
+ {
+ throw new NotImplementedException("The combination of urls has not been implemented.");
+ }
+ return Path.Combine(paths);
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs
index 4e190605c..f6087a43a 100644
--- a/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs
+++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs
@@ -37,7 +37,16 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
- var axis = serializer.Deserialize(reader, typeof(long[]));
+ int[]? axis;
+ if(reader.ValueType == typeof(long))
+ {
+ axis = new int[1];
+ axis[0] = (int)serializer.Deserialize(reader, typeof(int));
+ }
+ else
+ {
+ axis = serializer.Deserialize(reader, typeof(int[])) as int[];
+ }
if (axis is null)
{
throw new ValueError("Cannot deserialize 'null' to `Axis`.");
diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs
new file mode 100644
index 000000000..fce7bec58
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs
@@ -0,0 +1,36 @@
+using Newtonsoft.Json.Linq;
+using Newtonsoft.Json;
+
+namespace Tensorflow.Keras.Common
+{
+ public class CustomizedDTypeJsonConverter : JsonConverter
+ {
+ public override bool CanConvert(Type objectType)
+ {
+ return objectType == typeof(TF_DataType);
+ }
+
+ public override bool CanRead => true;
+
+ public override bool CanWrite => true;
+
+ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
+ {
+ var token = JToken.FromObject(dtypes.as_numpy_name((TF_DataType)value));
+ token.WriteTo(writer);
+ }
+
+ public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
+ {
+ if (reader.ValueType == typeof(string))
+ {
+ var str = (string)serializer.Deserialize(reader, typeof(string));
+ return dtypes.tf_dtype_from_name(str);
+ }
+ else
+ {
+ return (TF_DataType)serializer.Deserialize(reader, typeof(int));
+ }
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs
index 1ad19fc89..cfd8ee8f7 100644
--- a/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs
+++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs
@@ -46,7 +46,16 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
{
throw new ValueError("Cannot deserialize 'null' to `Shape`.");
}
- if(values.Length != 3)
+ if(values.Length == 1)
+ {
+ var array = values[0] as JArray;
+ if(array is null)
+ {
+ throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`.");
+ }
+ values = array.ToObject