diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs index 8ae2dae8f..9793798d2 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs @@ -149,4 +149,22 @@ public static void add_checkpoint_values_check(TrackableObjectGraph object_graph // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); // } } + + /// + /// Traverse the object graph and list all accessible objects. + /// + /// + public static IList list_objects(ObjectGraphView graph_view) + { + return objects_ids_and_slot_variables_and_paths(graph_view).Item1; + } + + internal static IEnumerable _objects_with_attributes(IEnumerable full_list) + { + return full_list.TakeWhile(x => + { + var saveables = x.gather_saveables_for_checkpoint(); + return saveables is not null && saveables.Count > 0; + }); + } } diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs new file mode 100644 index 000000000..0cc8e5fbd --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs @@ -0,0 +1,100 @@ +using Tensorflow.Util; + +namespace Tensorflow.Checkpoint +{ + sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle + { + public SafeCheckpointReaderHandle(): base() + { + + } + public SafeCheckpointReaderHandle(IntPtr handle): base(handle) + { + + } + + protected override bool ReleaseHandle() + { + c_api.TF_DeleteCheckpointReader(handle); + SetHandle(IntPtr.Zero); + return true; + } + } + public class CheckpointReader + { + private SafeCheckpointReaderHandle _handle; + public Dictionary VariableToDataTypeMap { get; set; } + public Dictionary VariableToShapeMap { get; set; } + + public CheckpointReader(string filename) + { + Status status = new Status(); + _handle = c_api.TF_NewCheckpointReader(filename, status.Handle); + status.Check(true); + ReadAllShapeAndType(); + } + + public int HasTensor(string name) + { + return c_api.TF_CheckpointReaderHasTensor(_handle, name); + } + + /// + /// Get the variable name. + /// + /// + /// + public string GetVariable(int index) + { + return c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index)); + } + + public int Size() + { + return c_api.TF_CheckpointReaderSize(_handle); + } + + public TF_DataType GetVariableDataType(string name) + { + return c_api.TF_CheckpointReaderGetVariableDataType(_handle, name); + } + + public Shape GetVariableShape(string name) + { + int num_dims = GetVariableNumDims(name); + long[] dims = new long[num_dims]; + Status status = new Status(); + c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle); + status.Check(true); + return new Shape(dims); + } + + public int GetVariableNumDims(string name) + { + return c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name); + } + + public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) + { + Status status = new Status(); + var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle); + status.Check(true); + return new Tensor(tensor); + } + + private void ReadAllShapeAndType() + { + VariableToDataTypeMap = new Dictionary(); + VariableToShapeMap = new Dictionary(); + int size = Size(); + for(int i = 0; i < size; i++) + { + var name = GetVariable(i); + var shape = GetVariableShape(name); + var dtype = GetVariableDataType(name); + VariableToDataTypeMap[name] = dtype; + VariableToShapeMap[name] = shape; + } + } + } +} diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs index 3267ae126..72372e410 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -175,9 +175,9 @@ public static (IList, object?) generate_saveable_objects( { var name = factory_data.name; var key = factory_data.checkpoint_key; - var maybe_saveable = factory_data.factory; + var maybe_saveable = saveable_object_util.create_saveable_object(name, key, factory_data.factory); - // TODO: oneflow python has a process with callable `saveable_factory`. + // TODO: tensorflow python has a process with callable `saveable_factory`. List saveables = new(); if (maybe_saveable.TryGet(out var s)) { @@ -217,7 +217,7 @@ public static (IList, object?) generate_saveable_objects( public record class CheckpointFactoryData ( - Maybe factory, + Func> factory, string name, string checkpoint_key ); diff --git a/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs new file mode 100644 index 000000000..f956e3337 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs @@ -0,0 +1,27 @@ +using System.Runtime.InteropServices; +using Tensorflow.Checkpoint; + +namespace Tensorflow +{ + public unsafe partial class c_api + { + [DllImport(TensorFlowLibName)] + internal static extern SafeCheckpointReaderHandle TF_NewCheckpointReader(string filename, SafeStatusHandle status); + [DllImport(TensorFlowLibName)] + internal static extern void TF_DeleteCheckpointReader(IntPtr reader); + [DllImport(TensorFlowLibName)] + internal static extern int TF_CheckpointReaderHasTensor(SafeCheckpointReaderHandle reader, string name); + [DllImport(TensorFlowLibName)] + internal static extern IntPtr TF_CheckpointReaderGetVariable(SafeCheckpointReaderHandle reader, int index); + [DllImport(TensorFlowLibName)] + internal static extern int TF_CheckpointReaderSize(SafeCheckpointReaderHandle reader); + [DllImport(TensorFlowLibName)] + internal static extern TF_DataType TF_CheckpointReaderGetVariableDataType(SafeCheckpointReaderHandle reader, string name); + [DllImport(TensorFlowLibName)] + internal static extern void TF_CheckpointReaderGetVariableShape(SafeCheckpointReaderHandle reader, string name, long[] dims, int num_dims, SafeStatusHandle status); + [DllImport(TensorFlowLibName)] + internal static extern int TF_CheckpointReaderGetVariableNumDims(SafeCheckpointReaderHandle reader, string name); + [DllImport(TensorFlowLibName)] + internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(SafeCheckpointReaderHandle reader, string name, SafeStatusHandle status); + } +} diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs index 0c2862dac..1934ffd5f 100644 --- a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs +++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs @@ -6,8 +6,12 @@ using Tensorflow.Contexts; using Tensorflow.Eager; using Tensorflow.Train; +using Tensorflow.Exceptions; using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; using static Tensorflow.Binding; +using Tensorflow.Operations; +using Newtonsoft.Json; +using Tensorflow.Training; namespace Tensorflow.Checkpoint; @@ -21,8 +25,20 @@ public class TrackableSaver private TrackableObjectGraph _last_save_object_graph; private Tensor? _object_graph_feed_tensor = null; private Tensor? _file_prefix_feed_tensor = null; + private Tensor? _file_prefix_placeholder = null; private Dictionary? _object_map = null; private object? _cache = null; + public Tensor? FilePrefixPlaceHolder + { + get + { + return _file_prefix_placeholder; + } + set + { + _file_prefix_placeholder = value; + } + } public TrackableSaver(ObjectGraphView graph_view) { _graph_view = graph_view; @@ -192,4 +208,366 @@ public Tensor save(string file_prefix, int? checkpoint_number = null, Session? s return save_path; } } + + public LoadStatus restore(string? save_path, CheckpointOptions? options = null) + { + if (options is null) + { + options = new CheckpointOptions(); + } + if(save_path is null) + { + return new InitializationOnlyStatus(_graph_view, ops.uid()); + } + + CheckpointReader reader = new CheckpointReader(save_path); + bool graph_building = tf.Context.executing_eagerly(); + Dictionary dtype_map = null; + if (!graph_building) + { + dtype_map = reader.VariableToDataTypeMap; + } + Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING); + + Dictionary file_prefix_feed_dict; + Tensor file_prefix_tensor; + if (graph_building) + { + if(_file_prefix_placeholder is null) + { + tf.device("/cpu:0"); + _file_prefix_placeholder = constant_op.constant("model"); + } + file_prefix_tensor = _file_prefix_placeholder; + file_prefix_feed_dict = new(); + file_prefix_feed_dict[_file_prefix_placeholder] = save_path; + } + else + { + tf.device("/cpu:0"); + file_prefix_tensor = constant_op.constant(save_path); + file_prefix_feed_dict = null; + } + TrackableObjectGraph object_graph_proto = new(); + if(object_graph_string.ndim > 0) + { + object_graph_proto.MergeFrom(object_graph_string.BufferToArray()); + } + else + { + object_graph_proto.MergeFrom(object_graph_string.StringBytes()[0]); + } + CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator( + object_graph_proto: object_graph_proto, + save_path: save_path, + save_path_tensor: file_prefix_tensor, + reader: reader, + restore_op_cache: null, + graph_view: _graph_view, + options: options, + saveables_cache: null + ); + + new CheckpointPosition(checkpoint, 0).restore(_graph_view.Root); + + if(_graph_view.AttachedDependencies is not null) + { + foreach(var refer in _graph_view.AttachedDependencies) + { + if(refer.Name == "root") + { + continue; + } + int? proto_id = null; + // Find proto ID of attached dependency (if it is in the proto). + foreach (var proto_refer in object_graph_proto.Nodes[0].Children) + { + if(proto_refer.LocalName == refer.Name) + { + proto_id = proto_refer.NodeId; + break; + } + } + + if (proto_id is null) + { + continue; + } + + // Object has already been restored. This can happen when there's an + // indirect connection from the attached object to the root. + if (checkpoint.ObjectByProtoId.ContainsKey(proto_id.Value)) + { + continue; + } + + new CheckpointPosition(checkpoint, proto_id.Value).restore(refer.Refer); + } + } + + return new CheckpointLoadStatus(checkpoint, file_prefix_feed_dict, _graph_view); + } +} + +public class CheckpointRestoreCoordinator +{ + private CheckpointOptions _options; + private TrackableObjectGraph _object_graph_proto; + private int _restore_uid; + private HashSet _matched_proto_ids; + private Tensor _save_path_tensor; + private string _save_path_string; + private CheckpointReader _reader; + private Dictionary _dtype_map; + private Dictionary _shape_map; + private ObjectGraphView _graph_view; + private Dictionary> _slot_restorations; + private bool _expect_partial_attr; + private List _restore_ops; + private List _all_trackables; + private Dictionary _object_by_proto_id; + private Dictionary _restore_ops_by_name; + private Dictionary> _deferred_slot_restorations; + private Dictionary> _unused_attributes; + + public CheckpointRestoreCoordinator(TrackableObjectGraph object_graph_proto, string save_path, Tensor save_path_tensor, + CheckpointReader reader, object? restore_op_cache, ObjectGraphView graph_view, CheckpointOptions options, object? saveables_cache) + { + // TODO(Rinne): cache. + _options = options; + _object_graph_proto = object_graph_proto; + _restore_uid = ops.uid(); + _save_path_tensor = save_path_tensor; + _save_path_string = save_path; + _reader = reader; + if(_reader is null) + { + _reader = new CheckpointReader(save_path); + } + _dtype_map = _reader.VariableToDataTypeMap; + _shape_map = _reader.VariableToShapeMap; + _graph_view = graph_view; + _restore_ops = new List(); + _restore_ops_by_name = new Dictionary(); + _all_trackables = new List(); + _matched_proto_ids = new HashSet(); + _object_by_proto_id = new Dictionary(); + _slot_restorations = new Dictionary>(); + _deferred_slot_restorations = new Dictionary>(); + + _expect_partial_attr = false; + for(int i = 0; i < _object_graph_proto.Nodes.Count; i++) + { + var node = _object_graph_proto.Nodes[i]; + foreach(var slot_reference in node.SlotVariables) + { + _slot_restorations.SetDefault(slot_reference.OriginalVariableNodeId, new List()) + .Add(new SlotVariableRestoration(i, slot_reference.SlotVariableNodeId, slot_reference.SlotName)); + } + } + + // skip the deleter and cache. + } + + public bool ExpectPartial + { + get + { + return _expect_partial_attr; + } + set + { + _expect_partial_attr = value; + } + } + + /// + /// Corresponding to `all_python_objects` of tensorflow python + /// + public List AllTrackables => _all_trackables; + public HashSet MatchedProtoIds => _matched_proto_ids; + public Dictionary ObjectByProtoId => _object_by_proto_id; + public int RestoreUid => _restore_uid; + public TrackableObjectGraph ObjectGraphProto => _object_graph_proto; + public Dictionary> SlotRestorations => _slot_restorations; + public Dictionary> DeferredSlotRestorations => _deferred_slot_restorations; + public Dictionary RestoreOpsByName => _restore_ops_by_name; + public Dictionary> UnusedAttributes => _unused_attributes; + + public void new_restore_ops(IEnumerable new_ops) + { + _restore_ops.AddRange(new_ops); + // skip the callback. + } + + public List restore_saveables(Dictionary> tensor_saveables, List positions, object? registered_savers = null) + { + List restore_ops = new(); + foreach(var position in positions) + { + var key = position.ObjectProto.Attributes[0].CheckpointKey; + throw new NotImplementedException(); + } + + Dictionary variable_dict = new(); + foreach(var item in tensor_saveables) + { + if(item.Value.TryGet(out var variable)) + { + variable_dict[item.Key] = variable; + } + else + { + throw new TypeError(); + } + } + + if (tensor_saveables is not null && tensor_saveables.Count > 0) + { + var flat_saveables = saveable_object_util.validate_and_slice_inputs(variable_dict); + var new_restore_ops = MultiDeviceSaver.from_saveables(flat_saveables).restore(_save_path_tensor, _options); + if (!tf.Context.executing_eagerly()) + { + foreach(var item in new_restore_ops) + { + restore_ops.Add(item.Value); + Debug.Assert(!_restore_ops_by_name.ContainsKey(item.Key)); + _restore_ops_by_name[item.Key] = item.Value; + } + } + } + return restore_ops; + } +} + +public abstract class LoadStatus +{ + public abstract LoadStatus assert_consumed(); + public abstract LoadStatus assert_existing_objects_matched(); + public abstract LoadStatus assert_nontrivial_match(); + public abstract LoadStatus run_restore_ops(Session? session = null); + public abstract void initialize_or_restore(Session? session = null); + public virtual LoadStatus expect_partial() + { + return this; + } +} + +public class InitializationOnlyStatus: LoadStatus +{ + private int _restore_uid; + private ObjectGraphView _object_graph_view; + private Trackable _root; + public InitializationOnlyStatus(ObjectGraphView object_graph_view, int restore_uid) + { + _restore_uid = restore_uid; + _object_graph_view = object_graph_view; + _root = object_graph_view.Root; + } + public override LoadStatus assert_consumed() + { + throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); + } + public override LoadStatus assert_existing_objects_matched() + { + throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); + } + public override LoadStatus assert_nontrivial_match() + { + throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); + } + public override LoadStatus run_restore_ops(Session? session = null) + { + throw new AssertionError("No checkpoint specified, so no restore ops are available " + + "(save_path=None to Saver.restore)."); + } + public override void initialize_or_restore(Session? session = null) + { + if (tf.Context.executing_eagerly()) + { + return; + } + if(session is null) + { + session = new Session(); + } + var trackable_objects = CheckPointUtils.list_objects(_object_graph_view); + throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } +} + +internal class CheckpointLoadStatus: LoadStatus +{ + private CheckpointRestoreCoordinator _checkpoint; + private Dictionary _feed_dict; + private ObjectGraphView _object_graph_view; + private Trackable _root; + public CheckpointLoadStatus(CheckpointRestoreCoordinator checkpoint, Dictionary feed_dict, ObjectGraphView graph_view):base() + { + _checkpoint = checkpoint; + _feed_dict = feed_dict; + _object_graph_view = graph_view; + _root = graph_view.Root; + } + + public CheckpointRestoreCoordinator Checkpoint => _checkpoint; + + public override LoadStatus assert_consumed() + { + throw new NotImplementedException(); + } + + public override LoadStatus assert_existing_objects_matched() + { + for(int i = 0; i < _checkpoint.ObjectGraphProto.Nodes.Count; i++) + { + var node = _checkpoint.ObjectGraphProto.Nodes[i]; + if(_checkpoint.ObjectByProtoId.TryGetValue(i, out var trackable) && + trackable.UpdateUid < _checkpoint.RestoreUid) + { + throw new AssertionError($"Object {node} not assigned a value from checkpoint."); + } + } + foreach(var trackable_object in CheckPointUtils.list_objects(_object_graph_view)) + { + if(trackable_object is TrackableDataStructure && trackable_object._trackable_children().Count == 0) + { + continue; + } + _checkpoint.AllTrackables.Add(trackable_object); + } + var unused_trackables = CheckPointUtils._objects_with_attributes(_checkpoint.AllTrackables) + .Except(_checkpoint.ObjectByProtoId.Values); + if (unused_trackables.Any()) + { + var num_unused_trackables = unused_trackables.Count(); + var num_variables_to_show = Math.Min(10, num_unused_trackables); + throw new AssertionError($"Found {num_unused_trackables} Python objects that were " + + $"not bound to checkpointed values, likely due to changes in the " + + $"Python program. Showing {num_variables_to_show} of " + + $"{num_unused_trackables} unmatched objects: " + + $"{{list(unused_python_objects)[:num_variables_to_show]}}"); + } + return this; + } + + public override LoadStatus assert_nontrivial_match() + { + throw new NotImplementedException(); + } + + public override LoadStatus expect_partial() + { + throw new NotImplementedException(); + } + + public override void initialize_or_restore(Session? session = null) + { + throw new NotImplementedException(); + } + + public override LoadStatus run_restore_ops(Session? session = null) + { + throw new NotImplementedException(); + } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs index 09904d684..96e6c8dd9 100644 --- a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs @@ -213,7 +213,7 @@ public IDictionary> restore(Tensor file_pref // tf python has code `with ops.device(restore_device):` here. tf.device(restore_device); // may be risky. - var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); + var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); Dictionary> restored_tensor_dict = new(); int idx = 0; diff --git a/src/TensorFlowNET.Core/Checkpoint/restore.cs b/src/TensorFlowNET.Core/Checkpoint/restore.cs new file mode 100644 index 000000000..b27396a79 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/restore.cs @@ -0,0 +1,331 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using Tensorflow.Train; +using Tensorflow.Training; +using static Tensorflow.Binding; + +namespace Tensorflow.Checkpoint; + +public class CheckpointPosition +{ + private CheckpointRestoreCoordinator _checkpoint; + private int _proto_id; + private bool _skip_restore; + public CheckpointPosition(CheckpointRestoreCoordinator checkpoint, int proto_id) + { + _checkpoint = checkpoint; + _proto_id = proto_id; + _skip_restore = false; + } + + public Trackable Trackable => _checkpoint.ObjectByProtoId[_proto_id]; + public CheckpointRestoreCoordinator Checkpoint => _checkpoint; + public TrackableObjectGraph.Types.TrackableObject ObjectProto => _checkpoint.ObjectGraphProto.Nodes[_proto_id]; + + public void restore(Trackable trackable) + { + using (ops.init_scope()) + { + if (bind_project(trackable)) + { + var restore_ops = _restore_descendants(); + if(restore_ops is not null && restore_ops.Count > 0) + { + _checkpoint.new_restore_ops(restore_ops); + } + } + } + } + + /// + /// Set a checkpoint<->object correspondence. + /// + /// + /// + public bool bind_project(Trackable trackable) + { + _checkpoint.AllTrackables.Add(trackable); + _checkpoint.MatchedProtoIds.Add(_proto_id); + if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment)) + { + // skip the `logging.warning`. + return false; + } + else + { + _checkpoint.ObjectByProtoId[_proto_id] = trackable; + return true; + } + } + + public (List, Dictionary>, List, object?) gather_ops_or_named_saveables() + { + // skip the registered_saver + + if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0) + { + return (new List(), new Dictionary>(), + new List(), null); + } + + var saveable_factories = saveable_object_util.saveable_objects_from_trackable(this.Trackable); + + List existing_restore_ops; + List positions = new(); + Dictionary> named_saveables; + if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) + { + (existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories); + } + else if(saveable_factories.Count > 0) + { + (existing_restore_ops, named_saveables) = _create_saveables_by_attribute_name(saveable_factories); + } + else + { + throw new NotImplementedException(); + } + return (existing_restore_ops, named_saveables, positions, null); + } + + public CheckpointPosition create_child_position(int node_id) + { + return new CheckpointPosition(_checkpoint, node_id); + } + + public (CheckpointPosition, BaseResourceVariable) create_slot_variable_position(Optimizer optimizer_object, BaseResourceVariable variable, + int slot_variable_id, string slot_name) + { + //CheckpointPosition slot_variable_position = new(Checkpoint, slot_variable_id); + + // TODO(Rinne): implement it. + return (null, null); + } + + /// + /// Creates a saveable using the _serialize_to_tensor method. + /// + /// + private (List, Dictionary>) _create_serialize_to_tensor_saveable( + IDictionary>> saveable_factories) + { + string suffix = SaveableCompat.get_saveable_name(this.Trackable); + suffix = suffix ?? ""; + var saveable_name = _extract_saveable_name(ObjectProto.Attributes[0].CheckpointKey) + suffix; + + if (!tf.Context.executing_eagerly()) + { + throw new NotImplementedException("The restore under graph mode has not been implemented. " + + "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } + + var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name); + // skip the cache. + Dictionary> dict = new(); + dict[saveable_name] = saveable; + return (new List(), dict); + } + + private (List, Dictionary>) _create_saveables_by_attribute_name( + IDictionary>> saveable_factories) + { + // TODO(Rinne): implement it. + if(ObjectProto.Attributes is null) + { + return (new List(), new Dictionary>()); + } + + List existing_restore_ops = new(); + HashSet created_compat_names = new(); + Dictionary> named_saveables = new(); + foreach (var serialized_tensor in ObjectProto.Attributes) + { + Operation existing_op; + if (tf.Context.executing_eagerly() || !_checkpoint.RestoreOpsByName.ContainsKey(serialized_tensor.CheckpointKey)) + { + existing_op = null; + } + else + { + existing_op = _checkpoint.RestoreOpsByName[serialized_tensor.CheckpointKey]; + } + + if(existing_op is not null) + { + existing_restore_ops.Add(existing_op); + continue; + } + + if(created_compat_names.Any(x => serialized_tensor.Name.StartsWith(x))) + { + continue; + } + + // TODO(Rinne): deal with cache. + + var saveable = _get_saveable_from_factory(saveable_factories, serialized_tensor, created_compat_names); + if(saveable is null) + { + _checkpoint.UnusedAttributes.SetDefault(_proto_id, new List()).Add(serialized_tensor.Name); + continue; + } + named_saveables[serialized_tensor.CheckpointKey] = saveable; + } + return (existing_restore_ops, named_saveables); + } + + private Maybe _get_saveable_from_factory(IDictionary>> saveable_factories, + TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet created_compat_names) + { + var expected_factory_name = serialized_tensor.Name; + var factory_input_name = serialized_tensor.CheckpointKey; + + if (!saveable_factories.TryGetValue(expected_factory_name, out var matched_factory)) + { + foreach(var item in saveable_factories) + { + var factory_name = item.Key; + var factory = item.Value; + if (expected_factory_name.StartsWith(factory_name)) + { + if(matched_factory is not null) + { + throw new ValueError($"Forward compatibility load error: Unable to load " + + "checkpoint saved in future version of TensorFlow. " + + "Please update your version of TensorFlow to the " + + "version in which the checkpoint was saved."); + } + } + matched_factory = factory; + factory_input_name = _extract_saveable_name(serialized_tensor.CheckpointKey) + factory_name; + created_compat_names.Add(factory_name); + } + } + return matched_factory(factory_input_name); + } + + private string _extract_saveable_name(string checkpoint_key) + { + var search_key = TrackableUtils.OBJECT_ATTRIBUTES_NAME + "/"; + return checkpoint_key.Substring(0, checkpoint_key.IndexOf(search_key) + search_key.Length); + } + + /// + /// Restore the bound Trackable and dependencies (may be deferred). + /// + private List _restore_descendants() + { + Queue<(CheckpointPosition, Trackable)> visit_queue = new(); + visit_queue.Enqueue((this, this.Trackable)); + List restore_ops = new(); + Dictionary> tensor_saveables = new(); + List positions = new(); + + CheckpointPosition current_position = null; + while (visit_queue.Count > 0) + { + current_position = visit_queue.Dequeue().Item1; + var (new_restore_ops, new_tensor_saveables, new_positions, new_registered_savers) = current_position._single_restore(); + restore_ops.AddRange(new_restore_ops); + foreach(var item in new_tensor_saveables) + { + tensor_saveables.Add(item.Key, item.Value); + } + positions.AddRange(new_positions); + _queue_children_for_restoration(current_position, visit_queue); + _queue_slot_variables(current_position, visit_queue); + } + restore_ops.AddRange(current_position.Checkpoint.restore_saveables(tensor_saveables, positions, null)); + return restore_ops; + } + + private void _queue_children_for_restoration(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue) + { + var trackable = checkpoint_position.Trackable; + foreach(var child in checkpoint_position.ObjectProto.Children) + { + var child_position = checkpoint_position.create_child_position(child.NodeId); + var local_object = trackable._lookup_dependency(child.LocalName); + var child_proto = child_position.ObjectProto; + if(local_object is null) + { + if(child_proto.Children.Any() || child_proto.Attributes.Any() || child_proto.SlotVariables.Any()) + { + trackable.DeferredDependencies.SetDefault(child.LocalName, new List()).Add(child_position); + } + } + else + { + if (child_position.bind_project(local_object)) + { + visit_queue.Enqueue((child_position, local_object)); + } + } + } + } + + private void _queue_slot_variables(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue) + { + var trackable = checkpoint_position.Trackable; + var checkpoint = checkpoint_position.Checkpoint; + if(checkpoint.DeferredSlotRestorations.TryGetValue(checkpoint_position._proto_id, out var positions)) + { + checkpoint.DeferredSlotRestorations.Remove(checkpoint_position._proto_id); + foreach (var deferred_slot_restoration in positions) + { + var (slot_variable_position, slot_variable) = checkpoint_position.create_slot_variable_position( + trackable as Optimizer, deferred_slot_restoration.OriginalVariable, deferred_slot_restoration.SlotVariableId, + deferred_slot_restoration.SlotName + ); + if(slot_variable_position is not null) + { + visit_queue.Enqueue((slot_variable_position, slot_variable)); + } + } + } + if (checkpoint.SlotRestorations.TryGetValue(checkpoint_position._proto_id, out var restorations)) + { + checkpoint.SlotRestorations.Remove(checkpoint_position._proto_id); + foreach (var slot_restoration in restorations) + { + if(Checkpoint.ObjectByProtoId.TryGetValue(slot_restoration.OptimizerId, out var optimizer_object)) + { + throw new NotImplementedException(); + // TODO(Rinne); implement it. + } + else + { + Debug.Assert(trackable is BaseResourceVariable); + Checkpoint.DeferredSlotRestorations.SetDefault(slot_restoration.OptimizerId, new List()) + .Add(new DeferredSlotVariableRestoration(trackable as BaseResourceVariable, slot_restoration.SlotVariableId, slot_restoration.SlotName)); + } + } + } + } + + private (List, Dictionary>, List, object?) _single_restore() + { + var trackable = this.Trackable; + trackable._maybe_initialize_trackable(); + if(_checkpoint.RestoreUid > trackable.UpdateUid) + { + var (restore_ops, tensor_saveables, positions, registered_savers) = gather_ops_or_named_saveables(); + trackable.UpdateUid = _checkpoint.RestoreUid; + return (restore_ops, tensor_saveables, positions, registered_savers); + } + else + { + return (new List(), new Dictionary>(), + new List(), null); + } + } +} + +public record class DeferredSlotVariableRestoration( + BaseResourceVariable OriginalVariable, + int SlotVariableId, + string SlotName +); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Eager/execute.cs b/src/TensorFlowNET.Core/Eager/execute.cs index cb3ea4d3c..2926f8e28 100644 --- a/src/TensorFlowNET.Core/Eager/execute.cs +++ b/src/TensorFlowNET.Core/Eager/execute.cs @@ -10,7 +10,7 @@ namespace Tensorflow.Eager { - internal class execute + internal static class execute { public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx) { @@ -27,5 +27,9 @@ public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] i return tensors; } + public static bool must_record_gradient() + { + return false; + } } } diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index bac9cedbf..a6720a5f3 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -13,8 +13,8 @@ namespace Tensorflow.Functions /// public class ConcreteFunction: Trackable { - FuncGraph func_graph; - ForwardBackwardCall forward_backward; + internal FuncGraph func_graph; + internal ForwardBackwardCall forward_backward; public Tensor[] Inputs => func_graph.Inputs; public Tensor[] CapturedInputs => func_graph.external_captures; @@ -23,6 +23,8 @@ public class ConcreteFunction: Trackable public Tensor[] Outputs; public Type ReturnType; public TensorSpec[] OutputStructure; + public IEnumerable ArgKeywords { get; set; } + public long NumPositionArgs { get; set; } public ConcreteFunction(string name) { @@ -163,6 +165,15 @@ public Tensors CallFlat(Tensor[] args, Tensor[] captured_inputs) return flat_outputs; } + public void AddTograph(Graph? g = null) + { + if(!tf.Context.executing_eagerly() && g is null) + { + g = ops.get_default_graph(); + } + // TODO(Rinne); complete it with `_delayed_rewrite_functions`. + } + ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) { var functions = new FirstOrderTapeGradientFunctions(func_graph, false); diff --git a/src/TensorFlowNET.Core/IO/gfile.cs b/src/TensorFlowNET.Core/IO/gfile.cs index 5f08702da..142b8b64e 100644 --- a/src/TensorFlowNET.Core/IO/gfile.cs +++ b/src/TensorFlowNET.Core/IO/gfile.cs @@ -16,8 +16,10 @@ limitations under the License. using System; using System.Collections.Generic; +using System.Diagnostics; using System.IO; using System.Linq; +using static Tensorflow.Binding; namespace Tensorflow.IO { @@ -63,5 +65,15 @@ public string[] glob(string data_dir) dirs.AddRange(Directory.GetFiles(dir)); return dirs.ToArray(); } + + public string join(params string[] paths) + { + Debug.Assert(paths.Length >= 1); + if (paths[0].Substring(1).Contains("://")) + { + throw new NotImplementedException("The combination of urls has not been implemented."); + } + return Path.Combine(paths); + } } } diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs index 4e190605c..f6087a43a 100644 --- a/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs @@ -37,7 +37,16 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) { - var axis = serializer.Deserialize(reader, typeof(long[])); + int[]? axis; + if(reader.ValueType == typeof(long)) + { + axis = new int[1]; + axis[0] = (int)serializer.Deserialize(reader, typeof(int)); + } + else + { + axis = serializer.Deserialize(reader, typeof(int[])) as int[]; + } if (axis is null) { throw new ValueError("Cannot deserialize 'null' to `Axis`."); diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs new file mode 100644 index 000000000..fce7bec58 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs @@ -0,0 +1,36 @@ +using Newtonsoft.Json.Linq; +using Newtonsoft.Json; + +namespace Tensorflow.Keras.Common +{ + public class CustomizedDTypeJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(TF_DataType); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + var token = JToken.FromObject(dtypes.as_numpy_name((TF_DataType)value)); + token.WriteTo(writer); + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + if (reader.ValueType == typeof(string)) + { + var str = (string)serializer.Deserialize(reader, typeof(string)); + return dtypes.tf_dtype_from_name(str); + } + else + { + return (TF_DataType)serializer.Deserialize(reader, typeof(int)); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs index 1ad19fc89..cfd8ee8f7 100644 --- a/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs @@ -46,7 +46,16 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer { throw new ValueError("Cannot deserialize 'null' to `Shape`."); } - if(values.Length != 3) + if(values.Length == 1) + { + var array = values[0] as JArray; + if(array is null) + { + throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); + } + values = array.ToObject(); + } + if (values.Length < 3) { throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); } @@ -54,19 +63,37 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer { throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`"); } - if (values[1] is not int) + int nodeIndex; + int tensorIndex; + if (values[1] is long) + { + nodeIndex = (int)(long)values[1]; + } + else if (values[1] is int) + { + nodeIndex = (int)values[1]; + } + else { throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`"); } - if (values[2] is not int) + if (values[2] is long) + { + tensorIndex = (int)(long)values[2]; + } + else if (values[1] is int) + { + tensorIndex = (int)values[2]; + } + else { throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`"); } return new NodeConfig() { Name = values[0] as string, - NodeIndex = (int)values[1], - TensorIndex = (int)values[2] + NodeIndex = nodeIndex, + TensorIndex = tensorIndex }; } } diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs index 300cb2f28..198662afe 100644 --- a/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs @@ -51,10 +51,28 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) { - var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; - if(dims is null) + long?[] dims; + try { - throw new ValueError("Cannot deserialize 'null' to `Shape`."); + dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; + } + catch (JsonSerializationException ex) + { + if (reader.Value.Equals("class_name")) + { + reader.Read(); + reader.Read(); + reader.Read(); + dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; + } + else + { + throw ex; + } + } + if (dims is null) + { + return null; } long[] convertedDims = new long[dims.Length]; for(int i = 0; i < dims.Length; i++) diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index 036291076..20a98e3d3 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -19,6 +19,7 @@ public interface ILayer: IWithTrackable, IKerasConfigable List TrainableVariables { get; } List TrainableWeights { get; } List NonTrainableWeights { get; } + List Weights { get; } Shape OutputShape { get; } Shape BatchInputShape { get; } TensorShapeConfig BuildInputShape { get; } diff --git a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs index cac19180f..934d3b151 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs @@ -1,8 +1,11 @@ using Newtonsoft.Json; +using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; namespace Tensorflow.Keras.Saving { diff --git a/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs b/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs index 4437ba0aa..9ff381299 100644 --- a/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs +++ b/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs @@ -3,6 +3,7 @@ using System.Text; using Tensorflow.Keras.Engine; using Tensorflow.Train; +using Tensorflow.Training.Saving.SavedModel; namespace Tensorflow.ModelSaving { diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 2b83dd1d1..4e9369a8b 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -71,6 +71,7 @@ public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell public List TrainableVariables => throw new NotImplementedException(); public List TrainableWeights => throw new NotImplementedException(); + public List Weights => throw new NotImplementedException(); public List NonTrainableWeights => throw new NotImplementedException(); public Shape OutputShape => throw new NotImplementedException(); diff --git a/src/TensorFlowNET.Core/Operations/gen_ops.cs b/src/TensorFlowNET.Core/Operations/gen_ops.cs index 956be96b5..26a9b5be8 100644 --- a/src/TensorFlowNET.Core/Operations/gen_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_ops.cs @@ -27189,8 +27189,33 @@ public static Tensor restore_slice(Tensor file_pattern, Tensor tensor_name, Tens /// /// Callers must ensure all the named tensors are indeed stored in the checkpoint. /// - public static Tensor[] restore_v2(Tensor prefix, Tensor tensor_names, Tensor shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2") + public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2") { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + try + { + Dictionary attrs = new(); + attrs["dtypes"] = dtypes; + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( + "RestoreV2", name, prefix, tensor_names, shape_and_slices + ) + { attrs = attrs }); + return result; + } + catch (Exception) + { + try + { + return restore_v2_eager_fallback(prefix, tensor_names, shape_and_slices, dtypes, name, ctx); + } + catch (Exception) + { + + } + } + } var dict = new Dictionary(); dict["prefix"] = prefix; dict["tensor_names"] = tensor_names; @@ -27202,6 +27227,22 @@ public static Tensor[] restore_v2(Tensor prefix, Tensor tensor_names, Tensor sha return (tensors); } + public static Tensor[] restore_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name, Context ctx) + { + prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING); + var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING); + var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING); + object[] attrs = new object[] { "dtypes", dtypes }; + Tensor[] inputs_flat = new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }; + var result = execute.quick_execute("RestoreV2", dtypes.Length, inputs_flat, attrs, ctx, name); + + if (execute.must_record_gradient()) + { + // TODO(Rinne); record the gradient + } + return result; + } + /// /// Reverses specific dimensions of a tensor. /// diff --git a/src/TensorFlowNET.Core/Operations/io_ops.cs b/src/TensorFlowNET.Core/Operations/io_ops.cs index 35c5877f3..16e1bac47 100644 --- a/src/TensorFlowNET.Core/Operations/io_ops.cs +++ b/src/TensorFlowNET.Core/Operations/io_ops.cs @@ -62,6 +62,7 @@ public Operation save_v2_eager_fallback(Tensor prefix, string[] tensor_names, st public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) { + // Note: this implementation is not correct in many cases, please consider using `gen_ops.restore_v2`. var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); return _op.outputs; diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index 1b1fa0037..6ce7a0b00 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -17,8 +17,8 @@ limitations under the License. using System; using System.Linq; using Tensorflow.Framework; -using Tensorflow.ModelSaving; using Tensorflow.Train; +using Tensorflow.Training.Saving.SavedModel; using Tensorflow.Variables; using static Tensorflow.CppShapeInferenceResult.Types; diff --git a/src/TensorFlowNET.Core/Tensors/TF_DataType.cs b/src/TensorFlowNET.Core/Tensors/TF_DataType.cs index 5fe28c5d1..0f514b429 100644 --- a/src/TensorFlowNET.Core/Tensors/TF_DataType.cs +++ b/src/TensorFlowNET.Core/Tensors/TF_DataType.cs @@ -1,9 +1,13 @@ -namespace Tensorflow +using Newtonsoft.Json; +using Tensorflow.Keras.Common; + +namespace Tensorflow { /// /// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. /// The enum values here are identical to corresponding values in types.proto. /// + [JsonConverter(typeof(CustomizedDTypeJsonConverter))] public enum TF_DataType { DtInvalid = 0, diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index deeb9e4b5..3563f91a0 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -159,7 +159,10 @@ public static TF_DataType tf_dtype_from_name(string name) "uint32" => TF_DataType.TF_UINT32, "int64" => TF_DataType.TF_INT64, "uint64" => TF_DataType.TF_UINT64, + "float16" => TF_DataType.TF_BFLOAT16, + "float32" => TF_DataType.TF_FLOAT, "single" => TF_DataType.TF_FLOAT, + "float64" => TF_DataType.TF_DOUBLE, "double" => TF_DataType.TF_DOUBLE, "complex" => TF_DataType.TF_COMPLEX128, "string" => TF_DataType.TF_STRING, diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs index 1309a6174..2fd0d1d83 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs @@ -39,6 +39,24 @@ public Tensor op _op = value; } } + public BaseResourceVariable variable + { + get + { + if (_op.TryGet(out var v)) + { + return v; + } + else + { + throw new TypeError("The _op is not a variable."); + } + } + set + { + _op = value; + } + } public SaveSpec[] specs; public string name; public string device; diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/LoadOptions.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/LoadOptions.cs new file mode 100644 index 000000000..df9bdc1b5 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/LoadOptions.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public record class LoadOptions + { + public bool allow_partial_checkpoint; + public string experimental_io_device; + public bool experimental_skip_checkpoint; + public VariablePolicy experimental_variable_policy; + + public LoadOptions(bool allow_partial_checkpoint = false, string experimental_io_device = null, + bool experimental_skip_checkpoint = false, string experimental_variable_policy = null) + { + this.allow_partial_checkpoint = allow_partial_checkpoint; + this.experimental_io_device = experimental_io_device; + this.experimental_skip_checkpoint = experimental_skip_checkpoint; + this.experimental_variable_policy = VariablePolicy.from_obj(experimental_variable_policy); + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs index fe0403c30..601882930 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs @@ -1,4 +1,5 @@ -using Tensorflow.Train; +using System; +using Tensorflow.Train; namespace Tensorflow; @@ -14,4 +15,10 @@ public class RevivedTypes // TODO: complete the implementation. return null; } + + public static Tuple> deserialize(object proto) + { + // TODO: complete the implementation. + return null; + } } diff --git a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveOptions.cs similarity index 83% rename from src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs rename to src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveOptions.cs index 45ebd884f..d42f52535 100644 --- a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveOptions.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using System.Text; -namespace Tensorflow.ModelSaving +namespace Tensorflow { /// /// Options for saving to SavedModel. @@ -35,7 +35,7 @@ private VariablePolicy(string policy) public bool save_variable_devices() { - return this != VariablePolicy.None; + return this != None; } /// @@ -45,14 +45,14 @@ public bool save_variable_devices() /// public static VariablePolicy from_obj(object obj) { - if (obj is null) return VariablePolicy.None; + if (obj is null) return None; if (obj is VariablePolicy) return (VariablePolicy)obj; var key = obj.ToString().ToLower(); return key switch { - null => VariablePolicy.None, - "save_variable_devices" => VariablePolicy.SAVE_VARIABLE_DEVICES, - "expand_distributed_variables" => VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, + null => None, + "save_variable_devices" => SAVE_VARIABLE_DEVICES, + "expand_distributed_variables" => EXPAND_DISTRIBUTED_VARIABLES, _ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.") }; } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs index 1be54287e..5752d7284 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs @@ -5,7 +5,6 @@ using Tensorflow.Checkpoint; using Tensorflow.Contexts; using Tensorflow.Functions; -using Tensorflow.ModelSaving; using Tensorflow.Train; using Tensorflow.Training; using pbc = global::Google.Protobuf.Collections; diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs new file mode 100644 index 000000000..341a12ab9 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Functions; + +namespace Tensorflow.Training.Saving.SavedModel +{ + /// + /// A class wraps a concrete function to handle different distributed contexts. + /// + internal class WrapperFunction: ConcreteFunction + { + public WrapperFunction(ConcreteFunction concrete_function): base(concrete_function.func_graph) + { + this.forward_backward = concrete_function.forward_backward; + this.Outputs = concrete_function.Outputs; + this.ReturnType = concrete_function.ReturnType; + this.OutputStructure = concrete_function.OutputStructure; + this.ArgKeywords = concrete_function.ArgKeywords; + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs new file mode 100644 index 000000000..5b482872d --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs @@ -0,0 +1,36 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Functions; +using Tensorflow.Util; + +namespace Tensorflow.Training.Saving.SavedModel +{ + public static class function_deserialization + { + public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function, + IDictionary concrete_functions) + { + var concrete_function = concrete_functions[saved_bare_concrete_function.ConcreteFunctionName]; + concrete_function.ArgKeywords = saved_bare_concrete_function.ArgumentKeywords.ToList(); + concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments; + + var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec); + concrete_function.AddTograph(); + return concrete_function; + } + + private static FunctionSpec _deserialize_function_spec_as_nonmethod(FunctionSpec function_spec_proto) + { + // TODO(Rinne); revise the implementation. + return new FunctionSpec() + { + Fullargspec = function_spec_proto.Fullargspec, + IsMethod = function_spec_proto.IsMethod, + InputSignature = function_spec_proto.InputSignature, + JitCompile = function_spec_proto.JitCompile + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs new file mode 100644 index 000000000..da999b376 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs @@ -0,0 +1,641 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Net.Sockets; +using System.Text; +using Tensorflow.Checkpoint; +using Tensorflow.Train; +using Tensorflow.Training; +using pbc = global::Google.Protobuf.Collections; +using static Tensorflow.Binding; +using System.Runtime.CompilerServices; +using Tensorflow.Variables; +using Tensorflow.Functions; +using Tensorflow.Training.Saving.SavedModel; + +namespace Tensorflow +{ + /// + /// Helper class to load an object-based SavedModel. + /// + public partial class Loader + { + private pbc::RepeatedField _asset_file_def; + private Dictionary> _operation_attributes; + private SavedObjectGraph _proto; + private string _export_dir; + private CheckpointOptions _checkpoint_options; + private LoadOptions _save_options; + private IDictionary)> _node_filters; + private Dictionary? _node_path_to_id; + private List? _filtered_nodes; + private List _ordered_node_ids; + private Dictionary)> _loaded_nodes; + private List _nodes; + private Dictionary> _node_setters; + public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto, string export_dir, + CheckpointOptions ckpt_options, LoadOptions save_options, IDictionary)> filters) + { + var meta_graph = saved_model_proto.MetaGraphs[0]; + _asset_file_def = meta_graph.AssetFileDef; + _operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr); + _proto = object_graph_proto; + _export_dir = export_dir; + // TODO: `this._concrete_functions` and `this._restored_concrete_functions` + _checkpoint_options = ckpt_options; + _save_options = save_options; + + // TODO: `this._pretty_printer` + + _node_filters = filters; + _node_path_to_id = _convert_node_paths_to_ints(); + _loaded_nodes = new Dictionary)>(); + foreach(var filter in filters) + { + _loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value; + } + + _filtered_nodes = _retrieve_all_filtered_nodes(); + + _ordered_node_ids = _generate_ordered_node_ids(); + + _load_all(); + + + if (!save_options.experimental_skip_checkpoint) + { + _restore_checkpoint(); + } + foreach(var node in _nodes) + { + // skip the process of `CapturableResource`. + } + } + + /// + /// Maps all string node paths in node_filters to the int node ids. + /// + /// + private Dictionary? _convert_node_paths_to_ints() + { + if( _node_filters is null) + { + return null; + } + Dictionary path_to_int = new(); + foreach(var node_id in _node_filters.Keys) + { + int int_node_id; + var node_path = node_id.Split('.'); + if (node_path[0] != "root") + { + throw new ValueError($"When passing string identifiers to node_filters, the first name" + + $" must be root. Received {node_path[0]}."); + } + int_node_id = 0; + for(int i = 0; i < node_path.Length - 1; i++) + { + var name = node_path[i + 1]; + int_node_id = _find_node_child(int_node_id, name, String.Join(".", node_path.Take(i + 1))); + } + path_to_int[node_id] = int_node_id; + } + return path_to_int; + } + + private int _find_node_child(int node_id, string child_name, string path) + { + foreach(var refer in _proto.Nodes[node_id].Children) + { + if(refer.LocalName == child_name) + { + return refer.NodeId; + } + } + throw new ValueError($"Unable to find node {path}."); + } + + private List? _retrieve_all_filtered_nodes() + { + if(_node_filters is null) + { + return null; + } + + HashSet all_filtered_nodes = new(); + Queue nodes_to_visit = new Queue(_node_filters.Keys); + + while(nodes_to_visit.Count > 0) + { + var node_path = nodes_to_visit.Dequeue(); + var node_id = _node_path_to_id[node_path]; + if (all_filtered_nodes.Contains(node_id)) + { + continue; + } + all_filtered_nodes.Add(node_id); + Trackable node = null; + Action setter = null; + if(_loaded_nodes.TryGetValue(node_id, out var res)) + { + (node, setter) = res; + } + if(node is not null) + { + node._maybe_initialize_trackable(); + } + + foreach(var refer in _proto.Nodes[node_id].Children) + { + Trackable children_object = null; + if(_loaded_nodes.TryGetValue(refer.NodeId, out var result)) + { + children_object = result.Item1; + } + // See if node already tracks the child reference, in which case add the child to the loaded_nodes dict. + if(children_object is null && node is not null) + { + children_object = node._lookup_dependency(refer.LocalName); + if(children_object is TrackableDataStructure) + { + // TODO: set setter as lambda. + + _loaded_nodes[refer.NodeId] = (children_object, setter); + } + } + string child_path = $"{node_path}.{refer.LocalName}"; + _node_path_to_id[child_path] = refer.NodeId; + nodes_to_visit.Enqueue(child_path); + } + } + + if (all_filtered_nodes.Contains(0)) + { + return null; + } + return all_filtered_nodes.ToList(); + } + + /// + /// Orders the node ids so that dependencies appear first. + /// + /// + private List _generate_ordered_node_ids() + { + List unordered_ids; + if(_filtered_nodes is null) + { + unordered_ids = Enumerable.Range(0, _proto.Nodes.Count).ToList(); + } + else + { + unordered_ids = new List(_filtered_nodes); + } + + Dictionary> dependency_map = new(); + foreach(var node_id in unordered_ids) + { + var deps = dependency_map.SetDefault(node_id, new List()); + if (_loaded_nodes.ContainsKey(node_id)) + { + continue; + } + var proto = _proto.Nodes[node_id]; + foreach(var dep in _get_node_dependencies(proto).Values.Distinct()) + { + deps.Add(dep); + if(_filtered_nodes is not null && !_filtered_nodes.Contains(dep)) + { + // TODO: add info with `_pretty_printer`. + throw new ValueError($"Unable to partially load SavedModel since the specified filter " + + $"does not include all required objects for loading (e.g. " + + $"variables used in functions or deserialization dependencies). " + + $"Please include this path in the filter: {dep}"); + } + } + int? prev_slot = null; + foreach(var slot_variable_proto in proto.SlotVariables) + { + var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId; + // The optimizer and original variable must be created before the slot + // variable, since the slot variable is generated using the Optimizer's + // add_slot API. + var slot_deps = dependency_map[slot_variable_node_id]; + slot_deps.Add(node_id); + slot_deps.Add(slot_variable_proto.OriginalVariableNodeId); + + if(prev_slot is not null) + { + slot_deps.Add(prev_slot.Value); + } + prev_slot = slot_variable_node_id; + } + } + try + { + return TrackableUtils.order_by_dependency(dependency_map.ToDictionary(x => x.Key, x => x.Value as IEnumerable)); + } + catch (TrackableUtils.CyclicDependencyError ex) + { + throw new ValueError("Encountered a cycle in the deserialization dependencies" + + "in the SavedModel. This is extremely unexpected, please" + + "file a bug and make sure you are not manually modifying the SavedModel."); + } + } + + /// + /// Returns a dictionary of all dependencies of an object. + /// + /// + /// + private Dictionary, int> _get_node_dependencies(SavedObject proto) + { + Dictionary, int> dependencies = new(); + foreach(var refer in proto.Dependencies) + { + dependencies[refer.LocalName] = refer.NodeId; + } + if(proto.KindCase == SavedObject.KindOneofCase.Function) + { + var concreete_functions = proto.Function.ConcreteFunctions; + foreach(var fn_name in concreete_functions) + { + foreach(var bound_input in _proto.ConcreteFunctions[fn_name].BoundInputs) + { + dependencies[bound_input] = bound_input; + } + } + } + else if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction) + { + var fn_name = proto.BareConcreteFunction.ConcreteFunctionName; + foreach(var bound_input in _proto.ConcreteFunctions[fn_name].BoundInputs) + { + dependencies[bound_input] = bound_input; + } + } + else if(proto.KindCase == SavedObject.KindOneofCase.Resource) + { + foreach(var child in proto.Children) + { + if(child.LocalName == "_create_resource") + { + dependencies["_create_resource"] = child.NodeId; + } + } + } + return dependencies; + } + + /// + /// Loads all nodes and functions from the SavedModel and their edges. + /// + private void _load_all() + { + _load_nodes(); + _load_edges(); + + _setup_remaining_functions(); + _load_checkpoint_save_and_restore_functions(); + } + + /// + /// Restores the checkpoint-related save/restore functions to all nodes. + /// + private void _load_checkpoint_save_and_restore_functions() + { + foreach(var (node_id, proto) in _iter_all_nodes()) + { + var node = get(node_id); + if(node is null) + { + // skip it because now we skip the restoration of `Function` and `ConcreteFunction`. + continue; + } + if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) + { + // Restore Trackable serialize- and restore-from-tensor functions. + Debug.Assert(proto.SaveableObjects.Count == 1); + var saveable_object_proto = proto.SaveableObjects.Values.First(); + var save_fn_id = saveable_object_proto.SaveFunction; + var restore_fn_id = saveable_object_proto.RestoreFunction; + + throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } + else + { + // Restore legacy SaveableObject functions. + Dictionary saveable_fn_by_name = new(); + foreach(var item in proto.SaveableObjects) + { + var name = item.Key; + var saveable_object_proto = item.Value; + var save_fn_id = saveable_object_proto.SaveFunction; + var restore_fn_id = saveable_object_proto.RestoreFunction; + saveable_fn_by_name[name] = (get(save_fn_id), get(restore_fn_id)); + } + node.SelfSaveableObjectFactories = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null); + } + } + } + + /// + /// Load all saved objects. + /// + private void _load_nodes() + { + // `nodes` maps from node ids to recreated objects + // `node_setters` maps from node ids to setter functions + // (same signature as setattr) for setting children. + var (nodes, node_setters) = _initialize_loaded_nodes(); + + Dictionary + slot_variable_node_ids = new(); + + foreach(var (node_id, proto) in _iter_all_nodes()) + { + foreach(var slot_variable_proto in proto.SlotVariables) + { + var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId; + slot_variable_node_ids[slot_variable_node_id] = (node_id, slot_variable_proto); + } + } + + // Re-create everything. + foreach (var (node_id, proto) in _iter_all_nodes()) + { + if (nodes.ContainsKey(node_id)) + { + continue; + } + else if (slot_variable_node_ids.ContainsKey(node_id)) + { + // Use the public Optimizer interface when creating slot variables. + var (optimizer_node_id, slot_variable_proto) = slot_variable_node_ids[node_id]; + var optimizer_object = nodes[optimizer_node_id]; + var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId]; + + // TODO: implement it. + throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." + + " Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); + } + else + { + // skip the function and concrete function. + if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction || proto.KindCase == SavedObject.KindOneofCase.Function) + { + nodes[node_id] = null; + node_setters[node_id] = null; + continue; + } + var (node, setter) = _recreate(proto, node_id, nodes); + nodes[node_id] = node; + node_setters[node_id] = setter; + } + } + + if (!nodes.ContainsKey(0)) + { + nodes[0] = _recreate_base_user_object().Item1; + } + _nodes = new List(); + for(int i = 0; i < _proto.Nodes.Count; i++) + { + _nodes.Add(nodes[i]); + } + _node_setters = node_setters; + } + + /// + /// Load state from checkpoint into the deserialized objects. + /// + private void _restore_checkpoint() + { + var variables_path = SavedModelUtils.get_variables_path(_export_dir); + var saver = new TrackableSaver(new ObjectGraphView(get(0))); + tf.device("CPU"); + saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); + LoadStatus load_status; + if (_save_options.allow_partial_checkpoint) + { + load_status = saver.restore(variables_path, _checkpoint_options).expect_partial(); + load_status.assert_nontrivial_match(); + } + else + { + load_status = saver.restore(variables_path, _checkpoint_options); + load_status.assert_existing_objects_matched(); + } + var ckpt = (load_status as CheckpointLoadStatus).Checkpoint; + + if (!tf.Context.executing_eagerly()) + { + throw new NotImplementedException("The checkpoint restore has not supported graph mode. " + + "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } + } + + /// + /// Adds edges from objects to other objects and functions. + /// + private void _load_edges() + { + foreach(var (node_id, object_proto) in _iter_all_nodes()) + { + _add_object_graph_edges(object_proto, node_id); + } + + if(_filtered_nodes is not null && _filtered_nodes.Contains(0)) + { + var root = get(0); + foreach(var node_path in _node_filters.Keys) + { + var loaded_node = _nodes[_node_path_to_id[node_path]]; + + var path = node_path.Split('.'); + var current_node = root; + foreach(var name in path.Skip(1).Take(path.Length - 2)) + { + // `hasattr` and `setattr` is used here + throw new NotImplementedException(); + } + // `hasattr` and `setattr` is used here + throw new NotImplementedException(); + } + } + } + + private void _setup_remaining_functions() + { + // TODO: implement it with concrete functions. + } + + public Trackable get(int node_id) + { + return _nodes[node_id]; + } + + public Trackable get(string node_id) + { + return get(_node_path_to_id[node_id]); + } + + /// + /// Adds edges from an object to its children. + /// + /// + /// + private void _add_object_graph_edges(SavedObject proto, int node_id) + { + var obj = _nodes[node_id]; + var setter = _node_setters[node_id]; + + foreach(var refer in proto.Children) + { + if(obj is null) + { + // skip it because now we skip the restoration of `Function` and `ConcreteFunction`. + continue; + } + setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); + // skip the process of "__call__" + } + } + + private (Dictionary, Dictionary>) _initialize_loaded_nodes() + { + Dictionary nodes = new(); + Dictionary> node_setters = new(); + foreach(var item in _loaded_nodes) + { + var node_id = item.Key; + var (node, setter) = item.Value; + nodes[node_id] = node; + node_setters[node_id] = setter; + } + return (nodes, node_setters); + } + + private IEnumerable<(int, SavedObject)> _iter_all_nodes() + { + foreach(var node_id in _ordered_node_ids) + { + yield return (node_id, _proto.Nodes[node_id]); + } + } + + private (Trackable, Action) _recreate(SavedObject proto, int node_id, IDictionary nodes) + { + // skip the registered classes. + + Dictionary, Trackable> dependencies = new(); + foreach(var item in _get_node_dependencies(proto)) + { + dependencies[item.Key] = nodes[item.Value]; + } + + return _recreate_default(proto, node_id, dependencies); + } + + /// + /// Creates a Python object from a SavedObject protocol buffer. + /// + /// + /// + /// + private (Trackable, Action) _recreate_default(SavedObject proto, int node_id, IDictionary, Trackable> dependencies) + { + return proto.KindCase switch + { + SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id), + SavedObject.KindOneofCase.Function => throw new NotImplementedException(), + SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(), + SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), + SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException() + }; + } + + private (Trackable, Action) _recreate_user_object(SavedUserObject? proto, int node_id) + { + // skip the check of proto identifier because of lack of property. + + var looked_up = RevivedTypes.deserialize(proto); + if(looked_up is null) + { + return _recreate_base_user_object(proto, node_id); + } + return (looked_up.Item1, looked_up.Item2); + } + + private (Trackable, Action) _recreate_base_user_object(SavedUserObject? proto = null, int? node_id = null) + { + return (new _UserObject(), setattr); + } + + private (BaseResourceVariable, Action) _recreate_variable(SavedVariable proto) + { + string name = proto.Name; + string dbg_name = !string.IsNullOrEmpty(name) ? name : ""; + + // TODO(Rinne): `validate_synchronization_aggregation_trainable` + + var (synchronization, aggregation, trainable) = ResourceVariable.validate_synchronization_aggregation_trainable( + proto.Synchronization, proto.Aggregation, proto.Trainable, dbg_name); + + var saved_device = proto.Device; + var load_with_device = _save_options.experimental_variable_policy.save_variable_devices() && !string.IsNullOrEmpty(saved_device); + + if (load_with_device) + { + tf.device(saved_device); + return (new UninitializedVariable( + shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), + dtype: (TF_DataType)proto.Dtype, + name: name, + trainable: trainable, + aggregation: aggregation + ), setattr); + } + else + { + return (new UninitializedVariable( + shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), + dtype: (TF_DataType)proto.Dtype, + name: name, + trainable: trainable, + aggregation: aggregation + ), setattr); + } + } + + private (ConcreteFunction, Action) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, + Dictionary, Trackable> dependencies) + { + throw new NotImplementedException(); + //var fn = function_deserialization.setup_bare_concrete_function(proto, ) + } + + // TODO: remove this to a common class. + public static Action setattr = (x, y, z) => + { + Debug.Assert(y is string); + var properties = x.GetType().GetProperties(); + foreach(var p in properties) + { + if((string)y == p.Name) + { + p.SetValue(x, z); + return; + } + } + // TODO(Rinne): check if the property has been set successfully. + //throw new ValueError($"Cannot find the property {y} of {x}."); + }; + + public class _UserObject: AutoTrackable + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs new file mode 100644 index 000000000..a92cb5509 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs @@ -0,0 +1,122 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Text; +using Tensorflow.Checkpoint; +using Tensorflow.Operations; +using Tensorflow.Train; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class Loader + { + public static SavedModel parse_saved_model(string export_dir) + { + var path_to_pbtxt = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PBTXT); + var path_to_pb = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PB); + + SavedModel saved_model = new SavedModel(); + if (File.Exists(path_to_pb)) + { + byte[] file_content; + using(var f = new FileStream(path_to_pb, FileMode.Open, FileAccess.Read)) + { + file_content = new byte[f.Length]; + Debug.Assert(f.Length <= int.MaxValue); + f.Read(file_content, 0, (int)f.Length); + } + // TODO: change to stream mode. + saved_model.MergeFrom(file_content); + return saved_model; + } + else if (File.Exists(path_to_pbtxt)) + { + throw new NotImplementedException(); + } + else + { + throw new IOException($"SavedModel file does not exist at: {export_dir}{Path.PathSeparator}" + + $"{{{Constants.SAVED_MODEL_FILENAME_PBTXT}|{Constants.SAVED_MODEL_FILENAME_PB}}}"); + } + } + + // TODO: revise the type of `tags` + public static Trackable load(string export_dir, object? tags = null, LoadOptions? options = null) + { + return load_partial(export_dir, null, tags, options)["root"]; + } + + public static IDictionary load_partial(string export_dir, IDictionary)>? filters, object? tags = null, LoadOptions? options = null) + { + if (options is null) + { + options = new LoadOptions(); + } + if (tags is not null) + { + throw new NotImplementedException(); + } + var (saved_model_proto, debug_info) = Loader.parse_saved_model_with_debug_info(export_dir); + + Trackable root = null; + Loader loader = null; + if (saved_model_proto.MetaGraphs.Count == 1 && saved_model_proto.MetaGraphs[0].ObjectGraphDef is not null) + { + // skip python code: `metrics.IncrementReadApi(_LOAD_V2_LABEL)` + var meta_graph_def = saved_model_proto.MetaGraphs[0]; + if (!BitConverter.IsLittleEndian) + { + SavedModelUtils.swap_function_tensor_content(meta_graph_def); + } + + var object_graph_proto = meta_graph_def.ObjectGraphDef; + var ckpt_options = new CheckpointOptions(options.experimental_io_device); + tf_with(ops.init_scope(), x => + { + loader = new Loader(object_graph_proto, saved_model_proto, export_dir, ckpt_options, options, filters); + root = loader.get(0); + // skip the assignment of `graph_debug_info`. + }); + // skip the assignment of `tensorflow_version` + // skip the assignment of `tensorflow_git_version` + // skip the process of `metrics`. + } + else + { + if(filters is not null && filters.Count > 0) + { + throw new ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any" + + " version) cannot be loaded with node filters."); + } + tf_with(ops.init_scope(), x => + { + throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); + }); + } + if(filters != null && filters.Count > 0) + { + return filters.Keys.ToDictionary(x => x, x => loader.get(x)); + } + else + { + var res = new Dictionary(); + res["root"] = root; + return res; + } + } + + public static (SavedModel, object?) parse_saved_model_with_debug_info(string export_dir) + { + var saved_model = parse_saved_model(export_dir); + + // TODO: implement debug info. + + return (saved_model, null); + } + + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs index 94760e3df..4313920f5 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs @@ -6,7 +6,6 @@ using Google.Protobuf; using Tensorflow.Checkpoint; using Tensorflow.Functions; -using Tensorflow.ModelSaving; using Tensorflow.Train; using Tensorflow.Exceptions; using static Tensorflow.Binding; diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs index 4cfe0b69b..47d8cbab9 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.Text; -using Tensorflow.ModelSaving; namespace Tensorflow.Training.Saving.SavedModel { diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index a6e21e3e5..208311229 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -68,6 +68,34 @@ public static MySaveableObject[] validate_and_slice_inputs(IVariableV1[] names_t return saveables.ToArray(); } + public static MySaveableObject[] validate_and_slice_inputs(Dictionary names_to_saveables) + { + var saveables = new List(); + var seen_ops = new List(); + + foreach (var (name, op) in enumerate(names_to_saveables)) + { + foreach (var converted_saveable_object in saveable_objects_for_op(op, name)) + _add_saveable(saveables, seen_ops, converted_saveable_object); + } + return saveables.ToArray(); + } + + public static MySaveableObject[] validate_and_slice_inputs(Dictionary names_to_saveables) + { + var saveables = new List(); + var seen_ops = new List(); + + foreach(var item in names_to_saveables.OrderBy(x => x.Key)) + { + foreach(var converted_saveable_object in saveable_objects_for_op(item.Value, item.Key)) + { + _add_saveable(saveables, seen_ops, converted_saveable_object); + } + } + return saveables.ToArray(); + } + private static void _add_saveable(List saveables, List seen_ops, T saveable) where T : MySaveableObject { if (seen_ops.Contains(saveable.op)) @@ -77,6 +105,15 @@ private static void _add_saveable(List saveables, List seen_ops, T seen_ops.Add(saveable.op); } + private static void _add_saveable(List saveables, List seen_ops, MySaveableObject saveable) + { + if (seen_ops.Contains(saveable.variable)) + throw new ValueError($"The same saveable will be restored with two names: {saveable.op.OriginalVar.Name}"); + + saveables.Add(saveable); + seen_ops.Add(saveable.variable); + } + /// /// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`. /// @@ -136,19 +173,20 @@ public static IEnumerable saveable_objects_for_op(Trackable ob { full_name = name + "_" + attr; } - if(factory.TryGet(out var variable)) + var op = factory(full_name); + if(op.TryGet(out var variable)) { - foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name)) + foreach (var v in saveable_objects_for_op(variable as Trackable, variable.Name)) { - yield return op; + yield return v; } } else { - var saveable = factory.GetValue(); - foreach (var op in saveable_objects_for_op(saveable, saveable.name)) + var saveable = op.GetValue(); + foreach (var v in saveable_objects_for_op(saveable, saveable.name)) { - yield return op; + yield return v; } } } @@ -214,20 +252,19 @@ public static Dictionary op_list_to_dict(IVariableV1[] op_list, return names_to_saveables; } - public static IDictionary> saveable_objects_from_trackable(Trackable obj) + public static IDictionary>> saveable_objects_from_trackable(Trackable obj) { // skip the process of type `PythonState` - if (trackable_has_serialize_to_tensor(obj)) + Maybe create_saveable(string name = "") { - var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME; // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. var tensor_dict = obj.serialize_to_tensors(); List specs = new(); List local_names = new(); string prefix = SaveableCompat.get_saveable_name(obj) ?? ""; - foreach(var pair in tensor_dict) + foreach (var pair in tensor_dict) { var tensor_name = pair.Key; var maybe_tensor = pair.Value; @@ -235,9 +272,9 @@ public static IDictionary> string spec_name = name + TrackableUtils.escape_local_name(tensor_name); IDictionary internal_dict; - if(maybe_tensor.TryGet(out var tensor)) + if (maybe_tensor.TryGet(out var tensor)) { - internal_dict= new Dictionary(); + internal_dict = new Dictionary(); internal_dict[""] = tensor; } else @@ -245,13 +282,18 @@ public static IDictionary> internal_dict = maybe_tensor.GetValue>(); } - foreach(var item in internal_dict) + foreach (var item in internal_dict) { specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); } } - Dictionary> res = new(); - res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix); + return new TrackableSaveable(obj, specs, name, local_names, prefix); + } + + if (trackable_has_serialize_to_tensor(obj)) + { + Dictionary>> res = new(); + res[TrackableUtils.SERIALIZE_TO_TENSORS_NAME] = create_saveable; return res; } else @@ -333,6 +375,28 @@ public static Func return restored_ops; }; } + + /// + /// Returns a dict of SaveableObject factories generated from loaded fns. + /// + /// + /// + public static IDictionary>> recreate_saveable_objects( + IDictionary saveable_fn_by_name, IEnumerable? temp_session) + { + if (saveable_fn_by_name.Count > 0) + { + throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } + var res = new Dictionary>>(); + return res; + } + + public static Maybe create_saveable_object(string name, string key, Func> factory, + bool call_with_mapped_captures = false) + { + return factory(key); + } } public class SaveableCompatibilityConverter: Trackable diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index 132571f2a..7c86a5802 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -20,8 +20,8 @@ limitations under the License. using System.Linq; using Tensorflow.Checkpoint; using Tensorflow.Keras.Saving.SavedModel; -using Tensorflow.ModelSaving; using Tensorflow.Training; +using Tensorflow.Training.Saving.SavedModel; using static Tensorflow.Binding; namespace Tensorflow.Train @@ -41,9 +41,10 @@ public static class Constants protected IDictionary _unconditional_dependency_names; protected IList _unconditional_checkpoint_dependencies; + protected Dictionary> _unconditional_deferred_dependencies; - protected IDictionary> _self_saveable_object_factories = - new Dictionary>(); + protected IDictionary>> _self_saveable_object_factories = + new Dictionary>>(); private bool _manual_tracking = true; private static Trackable _none = new AutoTrackable(); @@ -71,6 +72,18 @@ public virtual string ObjectIdentifier public IList UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } public IDictionary UnconditionalDependencyNames { get => _unconditional_dependency_names; } public IList CheckpointDependencies { get => UnconditionalCheckpointDependencies; } + public Dictionary> DeferredDependencies => _unconditional_deferred_dependencies; + public IDictionary>> SelfSaveableObjectFactories + { + get + { + return _self_saveable_object_factories; + } + set + { + _self_saveable_object_factories = value; + } + } /// /// Restore-on-create for a variable be saved with this `Checkpointable`. @@ -136,9 +149,11 @@ public void _maybe_initialize_trackable() _self_update_uid = -1; _unconditional_checkpoint_dependencies = new List(); _unconditional_dependency_names = new Dictionary(); + _unconditional_deferred_dependencies = new Dictionary>(); } - public virtual IDictionary _trackable_children(SaveType save_type, IDictionary>? cache) + public virtual IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, + IDictionary>? cache = null) { _maybe_initialize_trackable(); return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); @@ -174,10 +189,19 @@ public virtual Trackable _track_trackable(Trackable trackable, string name, bool /// public virtual void _handle_deferred_dependencies(string name, Trackable trackable) { - //_maybe_initialize_trackable(); - //trackable._maybe_initialize_trackable(); - - // TODO: complete the implementation. + _maybe_initialize_trackable(); + trackable._maybe_initialize_trackable(); + + if(_unconditional_deferred_dependencies.TryGetValue(name, out var dependencies)) + { + _unconditional_deferred_dependencies.Remove(name); + foreach(var checkpoint_position in dependencies.OrderByDescending(x => x.Checkpoint.RestoreUid)) + { + checkpoint_position.restore(trackable); + } + } + + // TODO(Rinne): deal with `_self_name_based_restores` } public virtual Trackable? _lookup_dependency(string name) @@ -225,12 +249,19 @@ public virtual List export_to_saved_model_graph(IDictionary> gather_saveables_for_checkpoint() + public virtual IDictionary>> gather_saveables_for_checkpoint() { + Maybe create_saveable(string name = "") + { + throw new NotImplementedException(); + //return new TrackableSaveable(this, null, name, null, null); + } if (saveable_object_util.trackable_has_serialize_to_tensor(this)) { // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). - throw new NotImplementedException(); + Dictionary>> res = new(); + res[""] = create_saveable; + return res; } else { @@ -259,4 +290,6 @@ public virtual IDictionary _restore_from_tensors(IDictionary< } public record class TrackableReference(string Name, Trackable Refer); + + public record class SlotVariableRestoration(int OptimizerId, int SlotVariableId, string SlotName); } diff --git a/src/TensorFlowNET.Core/Training/TrackableUtils.cs b/src/TensorFlowNET.Core/Training/TrackableUtils.cs index 390d95c75..05c513a83 100644 --- a/src/TensorFlowNET.Core/Training/TrackableUtils.cs +++ b/src/TensorFlowNET.Core/Training/TrackableUtils.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Checkpoint; using Tensorflow.Exceptions; using Tensorflow.Train; @@ -20,9 +21,9 @@ public CyclicDependencyError(IDictionary> leftover_dependency_map LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable()); } } - private static string _ESCAPE_CHAR = "."; - private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; - private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; + internal static string _ESCAPE_CHAR = "."; + internal static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; + internal static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; public static string object_path_to_string(IEnumerable node_path_arr) { diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 4005d5640..9b8cfcb5f 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -5,9 +5,9 @@ using Tensorflow.Train; using static Tensorflow.Binding; using System.Collections.Generic; -using Tensorflow.ModelSaving; using System.Diagnostics; using Tensorflow.Checkpoint; +using Tensorflow.Training.Saving.SavedModel; namespace Tensorflow { @@ -19,7 +19,11 @@ public class BaseResourceVariable : DisposableTrackableObject protected TF_DataType _dtype; public TF_DataType dtype => _dtype; protected string _handle_name; - protected string handle_name => _handle_name; + public string handle_name + { + get { return _handle_name; } + set { _handle_name = value; } + } protected string _unique_id; public string UniqueId => _unique_id; @@ -289,10 +293,10 @@ public virtual void write_object_proto(SavedObject proto, SaveOptions options) resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); } - public override IDictionary> gather_saveables_for_checkpoint() + public override IDictionary>> gather_saveables_for_checkpoint() { - var res = new Dictionary>(); - res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; + var res = new Dictionary>>(); + res[Trackable.Constants.VARIABLE_VALUE_KEY] = x => this; return res; } diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index 1645d7130..3b1f1e968 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -238,5 +238,23 @@ public NDArray eval(Session session = null) { return _graph_element.eval(session); } + + public static (VariableSynchronization, VariableAggregation, bool) validate_synchronization_aggregation_trainable( + VariableSynchronization? synchronization, VariableAggregation? aggregation, bool? trainable, string name) + { + if(aggregation is null) + { + aggregation = VariableAggregation.None; + } + if(synchronization is null) + { + synchronization = VariableSynchronization.Auto; + } + if (trainable is null) + { + trainable = synchronization != VariableSynchronization.OnRead; + } + return (synchronization.Value, aggregation.Value, trainable.Value); + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs index b0d1b2b6b..f4407265c 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs @@ -24,10 +24,10 @@ public static Functional from_config(ModelConfig config) /// /// /// - static (Tensors, Tensors, Dictionary) reconstruct_from_config(ModelConfig config) + public static (Tensors, Tensors, Dictionary) reconstruct_from_config(ModelConfig config, Dictionary? created_layers = null) { // Layer instances created during the graph reconstruction process. - var created_layers = new Dictionary(); + created_layers = created_layers ?? new Dictionary(); var node_index_map = new Dictionary<(string, int), int>(); var node_count_by_layer = new Dictionary(); var unprocessed_nodes = new Dictionary(); @@ -88,12 +88,7 @@ static void process_layer(Dictionary created_layers, layer = created_layers[layer_name]; else { - layer = layer_data.ClassName switch - { - "InputLayer" => InputLayer.from_config(layer_data.Config), - "Dense" => Dense.from_config(layer_data.Config), - _ => throw new NotImplementedException("") - }; + layer = generic_utils.deserialize_keras_object(layer_data.ClassName, layer_data.Config); created_layers[layer_name] = layer; } diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 44eaef534..33320101b 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -53,6 +53,11 @@ public Functional(Tensors inputs, Tensors outputs, string name = null) Inputs = inputs, Outputs = outputs }) + { + Initialize(inputs, outputs, name); + } + + internal void Initialize(Tensors inputs, Tensors outputs, string name = null) { _input_layers = new List(); _output_layers = new List(); @@ -70,7 +75,14 @@ protected void _init_graph_network(Tensors inputs, Tensors outputs) this.inputs = inputs; this.outputs = outputs; built = true; - _buildInputShape = inputs.shape; + if(inputs.Length > 0) + { + _buildInputShape = inputs.shape; + } + else + { + _buildInputShape = new Saving.TensorShapeConfig(); + } if (outputs.Any(x => x.KerasHistory == null)) base_layer_utils.create_keras_history(outputs); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs b/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs index a2d212cb3..81fc26355 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; namespace Tensorflow.Keras.Engine { @@ -14,5 +15,30 @@ protected void StackLayers(params ILayer[] layers) public virtual Shape ComputeOutputShape(Shape input_shape) => throw new NotImplementedException(""); + + protected List _gather_children_variables(bool include_trainable = false, bool include_non_trainable = false) + { + List res = new(); + var nested_layers = _flatten_layers(false, false); + foreach (var layer in nested_layers) + { + if (layer is Layer l) + { + if (include_trainable == true && include_non_trainable == true) + { + res.AddRange(l.Variables); + } + else if (include_trainable == true && include_non_trainable == false) + { + res.AddRange(l.TrainableVariables); + } + else if(include_trainable == false && include_non_trainable == true) + { + res.AddRange(l.NonTrainableVariables); + } + } + } + return res; + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs index fc405d872..ed5c2de0a 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs @@ -12,7 +12,7 @@ public abstract partial class Layer public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; - public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; + public string GetTrackingMetadata() => TrackableSavedModelSaver.TrackingMetadata; public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) { diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 31b37d681..3934950bd 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -14,6 +14,7 @@ You may obtain a copy of the License at limitations under the License. ******************************************************************************/ +using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.Linq; @@ -66,16 +67,74 @@ public abstract partial class Layer : AutoTrackable, ILayer public bool SupportsMasking { get; set; } protected List _trainable_weights; - public virtual List TrainableVariables => _trainable_weights; + public virtual List TrainableVariables => TrainableWeights; protected List _non_trainable_weights; - public List non_trainable_variables => _non_trainable_weights; + public List NonTrainableVariables => NonTrainableWeights; + public List Variables => Weights; + + public virtual List TrainableWeights + { + get + { + if (!this.Trainable) + { + return new List(); + } + var children_weights = _gather_children_variables(true); + return children_weights.Concat(_trainable_weights).Distinct().ToList(); + } + } + + public virtual List NonTrainableWeights + { + get + { + if (!this.Trainable) + { + var children_weights = _gather_children_variables(true, true); + return children_weights.Concat(_trainable_weights).Concat(_non_trainable_weights).Distinct().ToList(); + } + else + { + var children_weights = _gather_children_variables(include_non_trainable: true); + return children_weights.Concat(_non_trainable_weights).Distinct().ToList(); + } + } + } + + public virtual List Weights + { + get + { + return TrainableWeights.Concat(NonTrainableWeights).ToList(); + } + set + { + if (Weights.Count() != value.Count()) throw new ValueError( + $"You called `set_weights` on layer \"{this.name}\"" + + $"with a weight list of length {len(value)}, but the layer was " + + $"expecting {len(Weights)} weights."); + foreach (var (this_w, v_w) in zip(Weights, value)) + this_w.assign(v_w, read_value: true); + } + } protected int id; public int Id => id; protected string name; protected string base_name; - public string Name => name; + public string Name + { + get + { + return name; + } + set + { + name = value; + } + } protected bool computePreviousMask; protected List updates; @@ -85,10 +144,11 @@ public abstract partial class Layer : AutoTrackable, ILayer List inboundNodes; public List InboundNodes => inboundNodes; - List outboundNodes; public List OutboundNodes => outboundNodes; + public JObject SerializedAttributes { get; set; } + ThreadLocal callContext = new ThreadLocal(); public CallContext CallContext => callContext.Value; public Tensor[] input @@ -117,6 +177,11 @@ public Shape OutputShape protected List _self_tracked_trackables; public Layer(LayerArgs args) + { + Initialize(args); + } + + internal virtual void Initialize(LayerArgs args) { this.args = args; // A stateful layer is a layer whose updates are run during inference too, @@ -273,46 +338,9 @@ protected virtual void _init_set_name(string name, bool zero_based = true) public int count_params() { if (Trainable) - return layer_utils.count_params(this, weights); + return layer_utils.count_params(this, Weights); return 0; } - List ILayer.TrainableWeights - { - get - { - return _trainable_weights; - } - } - - List ILayer.NonTrainableWeights - { - get - { - return _non_trainable_weights; - } - } - - public List weights - { - get - { - var weights = new List(); - weights.AddRange(_trainable_weights); - weights.AddRange(_non_trainable_weights); - return weights; - } - set - { - if (weights.Count() != value.Count()) throw new ValueError( - $"You called `set_weights` on layer \"{this.name}\"" + - $"with a weight list of length {len(value)}, but the layer was " + - $"expecting {len(weights)} weights."); - foreach (var (this_w, v_w) in zip(weights, value)) - this_w.assign(v_w, read_value: true); - } - } - - public List Variables => weights; public virtual IKerasConfig get_config() => args; diff --git a/src/TensorFlowNET.Keras/Engine/Model.Save.cs b/src/TensorFlowNET.Keras/Engine/Model.Save.cs index a1e891f98..a3956cccc 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Save.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Save.cs @@ -33,7 +33,7 @@ public void save(string filepath, { using (SharedObjectSavingScope.Enter()) { - KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); + KerasSavedModelUtils.save_model(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); } } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index dd3e11a27..bbc6e8293 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -36,6 +36,8 @@ public partial class Model : Layer, IModel IVariableV1 _predict_counter; bool _base_model_initialized; bool stop_training; + + public bool IsGraphNetwork => _is_graph_network; public OptimizerV2 Optimizer { @@ -49,6 +51,12 @@ public Model(ModelArgs args) _init_batch_counters(); } + internal override void Initialize(LayerArgs args) + { + _init_batch_counters(); + base.Initialize(args); + } + void _configure_steps_per_execution(int steps_per_execution) { _steps_per_execution = tf.Variable(steps_per_execution, @@ -81,10 +89,11 @@ void _init_batch_counters() public override List Layers => _flatten_layers(recursive: false, include_self: false).ToList(); - public override List TrainableVariables + public override List TrainableWeights { get { + // skip the assertion of weights created. var variables = new List(); if (!Trainable) @@ -95,18 +104,40 @@ public override List TrainableVariables foreach (var trackable_obj in _self_tracked_trackables) { if (trackable_obj.Trainable) - variables.AddRange(trackable_obj.TrainableVariables); + variables.AddRange(trackable_obj.TrainableWeights); } - foreach (var layer in _self_tracked_trackables) + variables.AddRange(_trainable_weights); + + return variables.Distinct().ToList(); + } + } + + public override List NonTrainableWeights + { + get + { + // skip the assertion of weights created. + var variables = new List(); + + foreach (var trackable_obj in _self_tracked_trackables) { - if (layer.Trainable) - variables.AddRange(layer.TrainableVariables); + variables.AddRange(trackable_obj.NonTrainableWeights); } - // variables.AddRange(_trainable_weights); + if (!Trainable) + { + var trainable_variables = new List(); + foreach (var trackable_obj in _self_tracked_trackables) + { + variables.AddRange(trackable_obj.TrainableWeights); + } + variables.AddRange(trainable_variables); + variables.AddRange(_trainable_weights); + variables.AddRange(_non_trainable_weights); + } - return variables; + return variables.Distinct().ToList(); } } diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs index 4d87659bd..69665388b 100644 --- a/src/TensorFlowNET.Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs @@ -44,8 +44,6 @@ public Sequential(SequentialArgs args) : base(args.Inputs, args.Outputs, name: args.Name) { this.args = args; - if (args.Layers == null) - args.Layers = new List(); // SupportsMasking = true; _compute_output_and_mask_jointly = true; _auto_track_sub_layers = false; @@ -54,10 +52,17 @@ public Sequential(SequentialArgs args) _created_nodes = new List(); // Add to the model any layers passed to the constructor. - if (args.Layers != null) + if (args.Layers is not null) { - foreach (var layer in args.Layers) - add(layer); + InitLayers(args.Layers); + } + } + + public void InitLayers(IEnumerable layers) + { + foreach(var layer in layers) + { + add(layer); } } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs index 45f64720f..9cb5b7565 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs @@ -25,8 +25,7 @@ public override void build(Shape input_shape) { throw new ValueError("Alpha must be a number greater than 0."); } - _buildInputShape = input_shape; - built = true; + base.build(input_shape); } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs index 2fd2caee1..981f96f0b 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs @@ -14,8 +14,7 @@ public Exponential(LayerArgs args) : base(args) } public override void build(Shape input_shape) { - _buildInputShape = input_shape; - built = true; + base.build(input_shape); } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { diff --git a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs index 1ef8d0e58..9b5bc0e66 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs @@ -19,8 +19,7 @@ public override void build(Shape input_shape) { if ( alpha < 0f ) { throw new ValueError("Alpha must be a number greater than 0."); } - _buildInputShape = input_shape; - built = true; + base.build(input_shape); } protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { Tensor output = inputs; diff --git a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs index ca8007d09..56fde9f2c 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs @@ -85,10 +85,5 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? train return outputs; } - - public static Dense from_config(LayerArgs args) - { - return new Dense(args as DenseArgs); - } } } diff --git a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs index 03b4b742a..a44c0bded 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs @@ -102,11 +102,6 @@ public InputLayer(InputLayerArgs args) : name: Name); } - public static InputLayer from_config(LayerArgs args) - { - return new InputLayer(args as InputLayerArgs); - } - public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this); } } diff --git a/src/TensorFlowNET.Keras/Metrics/Metric.cs b/src/TensorFlowNET.Keras/Metrics/Metric.cs index 1dfc39c49..435eebd48 100644 --- a/src/TensorFlowNET.Keras/Metrics/Metric.cs +++ b/src/TensorFlowNET.Keras/Metrics/Metric.cs @@ -56,7 +56,7 @@ public virtual Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_w public virtual void reset_states() { - foreach (var v in weights) + foreach (var v in Weights) v.assign(0); } diff --git a/src/TensorFlowNET.Keras/Models/ModelsApi.cs b/src/TensorFlowNET.Keras/Models/ModelsApi.cs index 73b77bc42..6597f5cdc 100644 --- a/src/TensorFlowNET.Keras/Models/ModelsApi.cs +++ b/src/TensorFlowNET.Keras/Models/ModelsApi.cs @@ -4,6 +4,7 @@ using System.Text; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Saving.SavedModel; using ThirdParty.Tensorflow.Python.Keras.Protobuf; namespace Tensorflow.Keras.Models @@ -13,20 +14,9 @@ public class ModelsApi public Functional from_config(ModelConfig config) => Functional.from_config(config); - public void load_model(string filepath, bool compile = true) + public Model load_model(string filepath, bool compile = true, LoadOptions? options = null) { - var bytes = File.ReadAllBytes(Path.Combine(filepath, "saved_model.pb")); - var saved_mode = SavedModel.Parser.ParseFrom(bytes); - - var meta_graph_def = saved_mode.MetaGraphs[0]; - var object_graph_def = meta_graph_def.ObjectGraphDef; - - bytes = File.ReadAllBytes(Path.Combine(filepath, "keras_metadata.pb")); - var metadata = SavedMetadata.Parser.ParseFrom(bytes); - - // Recreate layers and metrics using the info stored in the metadata. - var keras_loader = new KerasObjectLoader(metadata, object_graph_def); - keras_loader.load_layers(compile: compile); + return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model; } } } diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs index fc8cab0c1..fffc2bac0 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs @@ -1,12 +1,24 @@ using Newtonsoft.Json; +using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; using System.Linq; +using System.Reflection; using System.Text.RegularExpressions; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Layers.Rnn; +using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Keras.Utils; +using Tensorflow.Train; +using Tensorflow.Training; using ThirdParty.Tensorflow.Python.Keras.Protobuf; +using static Tensorflow.ApiDef.Types; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -14,17 +26,29 @@ namespace Tensorflow.Keras.Saving { public class KerasObjectLoader { - SavedMetadata _metadata; - SavedObjectGraph _proto; - Dictionary _node_paths = new Dictionary(); - Dictionary model_layer_dependencies = new Dictionary(); - List _traversed_nodes_from_config = new List(); + private static readonly IDictionary PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects; + private SavedMetadata _metadata; + private SavedObjectGraph _proto; + private Dictionary _node_paths = new Dictionary(); + private Dictionary model_layer_ids_dependencies = new Dictionary(); + private Dictionary model_layer_dependencies = new Dictionary(); + private List _traversed_nodes_from_config = new List(); + private Dictionary)> loaded_nodes; + private List _models_to_reconstruct; + public Dictionary)> LoadedNodes => loaded_nodes; + + static KerasObjectLoader() + { + PUBLIC_ATTRIBUTES[Keras.Saving.SavedModel.Constants.KERAS_ATTR] = null; + } public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def) { _metadata = metadata; _proto = object_graph_def; _metadata.Nodes.ToList().ForEach(x => _node_paths[x.NodeId] = x.NodePath); + _models_to_reconstruct = new List(); + loaded_nodes = new Dictionary)>(); } /// @@ -42,15 +66,255 @@ public void load_layers(bool compile = true) continue; } - _load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); + loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); + } + foreach(var node_metadata in metric_list) + { + try + { + if (node_metadata.Identifier.Equals("_tf_keras_metric")) + { + continue; + } + loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, + node_metadata.Metadata); + } + catch(ValueError e) + { + if (compile) + { + throw e; + } + // TODO: add logging.warning. + } + } + } + + public string get_path(int node_id) + { + return _node_paths[node_id]; + } + + /// + /// Finish setting up Keras objects. + /// + /// This function is executed after all objects and functions have been created. + /// Call functions and losses are attached to each layer, and once all layers + /// have been fully set up, graph networks are initialized. + /// + /// Subclassed models that are revived from the SavedModel are treated like + /// layers, and have their call/loss functions attached here. + /// + public void finalize_objects() + { + List layers_revived_from_config = new(); + List layers_revived_from_saved_model = new(); + foreach(var item in loaded_nodes) + { + var node_id = item.Key; + var node = item.Value.Item1; + if(node is not Layer || model_layer_ids_dependencies.ContainsKey(node_id)) + { + continue; + } + + _unblock_model_reconstruction(node_id, node as Layer); + + if(node is InputLayer or Metric) + { + continue; + } + + // TODO: deal with `RevivedLayer` and `RevivedInputLayer`. + layers_revived_from_config.Add(node as Layer); + } + + _finalize_saved_model_layers(layers_revived_from_saved_model); + _finalize_config_layers(layers_revived_from_config); + + _reconstruct_all_models(); + } + + private void _reconstruct_all_models() + { + HashSet all_initialized_models = new(); + for(int i = _models_to_reconstruct.Count - 1; i >= 0; i--) + { + int model_id = _models_to_reconstruct[i]; + all_initialized_models.Add(model_id); + var (model, layers) = model_layer_dependencies[model_id]; + _reconstruct_model(model_id, model, layers.ToList()); + _finalize_config_layers(new List() { model }); + } + + Debug.Assert(all_initialized_models.SequenceEqual(model_layer_dependencies.Keys)); + } + + private void _reconstruct_model(int model_id, Model model, List layers) + { + var config = JsonConvert.DeserializeObject(_metadata.Nodes[model_id].Metadata)["config"]; + + if(model.input is not null && model.input.Length > 0) + { + + } + else if(model is Sequential s) + { + if(layers is null || layers.Count == 0 || layers[0] is not InputLayer) + { + if (config["layers"][0]["class_name"].ToObject() == "InputLayer") + { + layers.Insert(0, new InputLayer(config["layers"][0]["config"].ToObject())); + } + else if (config["layers"][0]["config"]["batch_input_shape"] is not null) + { + // TODO(Rinne): implement it + } + } + + // `model.__init__(layers, config["name"])` + s.InitLayers(layers); + s.Name = config["name"].ToObject(); + if(s.input is null || s.input.Length == 0) + { + var first_layer = _get_child_layer_node_ids(model_id)[0]; + var input_specs = _infer_inputs(first_layer); + var input_shapes = _infer_inputs(first_layer, true); + // `model._set_inputs(input_specs)` + + // skip the check of input_specs is Dictionary + if (!s.Built) + { + s.build(input_shapes); + } + } + } + else + { + // skip the parameter `created_layers`. + var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(generic_utils.deserialize_model_config(config), + layers.ToDictionary(x => x.Name, x => x as ILayer)); + // skip the `model.__init__` + (model as Functional).Initialize(inputs, outputs, config["name"].ToObject()); + (model as Functional).connect_ancillary_layers(created_layers); + } + + _set_network_attributes_from_metadata(model); + _unblock_model_reconstruction(model_id, model); + } + + private void _set_network_attributes_from_metadata(Model revived_object) + { + // TODO: implement it. + } + + /// + /// Runs the final steps of loading Keras Layers from config. + /// + /// + private void _finalize_config_layers(List layers) + { + foreach(var layer in layers) + { + if (_is_graph_network(layer)) + { + _restore_layer_unconditional_losses(layer); + } + _restore_layer_activation_loss(layer); + _restore_layer_metrics(layer); + + // TODO(Rinne): deal with RNN. + } + } + + /// + /// Runs the final steps of loading Keras Layers from SavedModel. + /// + /// + private void _finalize_saved_model_layers(List layers) + { + foreach(var layer in layers) + { + // TODO(Rinne): deal with `RevivedNetwork`. + + _restore_layer_unconditional_losses(layer); + _restore_layer_activation_loss(layer); + _restore_layer_metrics(layer); + } + } + + private void _restore_layer_unconditional_losses(Layer layer) + { + // TODO(Rinne): implement it. + } + + private void _restore_layer_activation_loss(Layer layer) + { + // TODO(Rinne): implement it. + } + + private void _restore_layer_metrics(Layer layer) + { + // TODO(Rinne): implement it. + } + + /// + /// Removes layer from blocking model reconstruction. + /// + /// + /// + private void _unblock_model_reconstruction(int layer_id, Layer layer) + { + foreach(var depencency in model_layer_ids_dependencies) + { + var layer_ids = depencency.Value.Item2; + var layers = model_layer_dependencies.SetDefault(depencency.Key, + (depencency.Value.Item1, new Layer[depencency.Value.Item2.Length])).Item2; + if (!layer_ids.Contains(layer_id)) + { + continue; + } + layers[Array.IndexOf(layer_ids, layer_id)] = layer; + if (layers.All(x => x is not null)) + { + _models_to_reconstruct.Add(depencency.Key); + } } } - void _load_layer(int node_id, string identifier, string metadata_json) + private (Trackable, Action) _load_layer(int node_id, string identifier, string metadata_json) { - metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); var metadata = JsonConvert.DeserializeObject(metadata_json); - _revive_from_config(identifier, metadata, node_id); + + if (loaded_nodes.ContainsKey(node_id)) + { + var (node, setter) = loaded_nodes[node_id]; + + _maybe_add_serialized_attributes(node as Layer, metadata); + var config = metadata.Config; + if(_is_graph_network(node as Layer) && generic_utils.validate_config(config)) + { + Debug.Assert(node is Model); + var child_nodes = _get_child_layer_node_ids(node_id); + model_layer_ids_dependencies[node_id] = (node as Model, child_nodes); + if(child_nodes is null || child_nodes.Length == 0) + { + _models_to_reconstruct.Add(node_id); + } + } + return (node, setter); + } + else + { + var (obj, setter) = _revive_from_config(identifier, metadata, node_id); + if (obj is null) + { + (obj, setter) = _revive_custom_object(identifier, metadata); + } + Debug.Assert(obj is Layer); + _maybe_add_serialized_attributes(obj as Layer, metadata); + return (obj, setter); + } } /// @@ -59,11 +323,34 @@ void _load_layer(int node_id, string identifier, string metadata_json) /// /// /// - void _revive_from_config(string identifier, KerasMetaData metadata, int node_id) + private (Trackable, Action) _revive_from_config(string identifier, KerasMetaData metadata, int node_id) { - var obj = _revive_graph_network(identifier, metadata, node_id); - obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id); + Trackable obj; + if(identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER) + { + // TODO(Rinne): implement it. + return (null, null); + //throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); + } + else + { + obj = _revive_graph_network(identifier, metadata, node_id); + obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id); + } + + if(obj is null) + { + return (null, null); + } + var setter = _config_node_setter(_revive_setter); _add_children_recreated_from_config(obj, _proto.Nodes[node_id], node_id); + return (obj, setter); + } + + private (Trackable, Action) _revive_custom_object(string identifier, KerasMetaData metadata) + { + // TODO(Rinne): implement it. + throw new NotImplementedException(); } Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id) @@ -71,6 +358,12 @@ Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_ var config = metadata.Config; var class_name = metadata.ClassName; Model model = null; + + if(!metadata.IsGraphNetwork && class_name != "Sequential" && class_name != "Functional") + { + return null; + } + if (class_name == "Sequential") { model = new Sequential(new SequentialArgs @@ -78,34 +371,82 @@ Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_ Name = config.GetValue("name").ToString() }); } - else if (class_name == "Functional") + else if(identifier == Keras.Saving.SavedModel.Constants.SEQUENTIAL_IDENTIFIER) { - throw new NotImplementedException(""); + model = new Sequential(new SequentialArgs + { + Name = class_name + }); + } + else + { + model = new Functional(new Tensors(), new Tensors(), config["name"].ToObject()); } - - if (!metadata.IsGraphNetwork) - return null; // Record this model and its layers. This will later be used to reconstruct // the model. var layers = _get_child_layer_node_ids(node_id); - model_layer_dependencies[node_id] = (model, layers); + model_layer_ids_dependencies[node_id] = (model, layers); + if(layers is null || layers.Length == 0) + { + _models_to_reconstruct.Add(node_id); + } return model; } - Model _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id) + Layer _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id) { var config = metadata.Config; var class_name = metadata.ClassName; var shared_object_id = metadata.SharedObjectId; var must_restore_from_config = metadata.MustRestoreFromConfig; - var obj = class_name switch - { - "Resizing" => Resizing.from_config(config), - _ => throw new NotImplementedException("") - }; + + var obj = generic_utils.deserialize_keras_object(class_name, config); + + obj.Name = metadata.Name; + // TODO(Rinne): add `trainable`, `dtype`, `stateful` and `save_spec` + + var built = _try_build_layer(obj, node_id, metadata.BuildInputShape); - return null; + if (!built) + { + return null; + } + return obj; + } + + private void _revive_setter(object layer, object name, object value) + { + Debug.Assert(name is string); + Debug.Assert(layer is Layer); + if(PUBLIC_ATTRIBUTES.ContainsKey(name as string)) + { + if(value is Trackable) + { + (layer as Layer)._track_trackable(value as Trackable, name as string); + } + if((layer as Layer).SerializedAttributes is null) + { + (layer as Layer).SerializedAttributes = new JObject(); + } + (layer as Layer).SerializedAttributes[name as string] = JToken.FromObject(value); + } + else if(layer is Functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) + { + (layer as Functional)._track_trackable(value as Trackable, name as string, overwrite: true); + } + else + { + var properties = layer.GetType().GetProperties(); + foreach(var p in properties) + { + if(p.Name == name as string && p.GetValue(layer) is not null) + { + return; + } + } + Loader.setattr(layer, name, value); + } } /// @@ -143,34 +484,186 @@ int[] _get_child_layer_node_ids(int node_id) /// /// /// - void _add_children_recreated_from_config(Model obj, SavedObject proto, int node_id) + void _add_children_recreated_from_config(Trackable obj, SavedObject proto, int node_id) { if (_traversed_nodes_from_config.Contains(node_id)) return; var parent_path = _node_paths[node_id]; _traversed_nodes_from_config.Add(node_id); - if (!obj.Built) + obj._maybe_initialize_trackable(); + + if(obj is Layer layer && !layer.Built) { - var metadata_json = proto.UserObject.Metadata.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); - var metadata = JsonConvert.DeserializeObject(metadata_json); - _try_build_layer(obj, node_id, metadata.BuildInputShape); + var metadata = JsonConvert.DeserializeObject(_metadata.Nodes[node_id].Metadata); + _try_build_layer(layer, node_id, metadata.BuildInputShape); + } + + + List<(Trackable, int, string)> children = new(); + foreach(var refer in proto.Children) + { + var obj_child = obj._lookup_dependency(refer.LocalName); + children.Add((obj_child, refer.NodeId, refer.LocalName)); + } + + var metric_list_node_id = _search_for_child_node(node_id, new string[] { + Keras.Saving.SavedModel.Constants.KERAS_ATTR, "layer_metrics" + }); + if(metric_list_node_id is not null && obj is Model model && model.metrics is not null) + { + var obj_metrics = model.metrics.ToDictionary(x => x.Name, x => x); + foreach(var refer in _proto.Nodes[metric_list_node_id.Value].Children) + { + if (obj_metrics.TryGetValue(refer.LocalName, out var metric)) + { + var metric_path = $"{Keras.Saving.SavedModel.Constants.KERAS_ATTR}.layer_metrics.{refer.LocalName}"; + children.Add((metric as Metric, refer.NodeId, metric_path)); + } + } + } + + foreach(var (obj_child, child_id, child_name) in children) + { + if(obj_child is null) + { + continue; + } + var child_proto = _proto.Nodes[child_id]; + + // skip the check for registered identifier + + Action setter; + if (Keras.Saving.SavedModel.Constants.KERAS_OBJECT_IDENTIFIERS.Contains(obj_child.ObjectIdentifier)) + { + setter = _revive_setter; + } + else + { + setter = Loader.setattr; + } + + if (loaded_nodes.ContainsKey(child_id)) + { + // skip the logging.warning + continue; + } + + if(child_proto.KindCase == SavedObject.KindOneofCase.Variable && !string.IsNullOrEmpty(child_proto.Variable.Name)) + { + (obj_child as BaseResourceVariable).handle_name = child_proto.Variable.Name + ":0"; + } + + if(obj_child is TrackableDataStructure) + { + setter = (x, y, z) => { }; + } + + var child_path = $"{parent_path}.{child_name}"; + _node_paths[child_id] = child_path; + _add_children_recreated_from_config(obj_child, child_proto, child_id); + loaded_nodes[child_id] = (obj_child, setter); } } - bool _try_build_layer(Model obj, int node_id, Shape build_input_shape) + private bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape) { if (obj.Built) return true; + if(build_input_shape is null) + { + build_input_shape = _infer_inputs(node_id, convert_to_shapes: true); + } + + if(build_input_shape is not null) + { + obj.build(build_input_shape); + // In tf python here is a `base_layer.Layer.build(obj, build_input_shape)`. + // On the one hand, C# does not support call a method from specified parent class. + // On the other hand, currently All class derived from Layer call `Layer.Build` or + // move the implementation of `Layer.build` to its own `build` method. + // Therefore we do not call it here. + // However, it's still quite risky once in the future a certain class derived from + // `Layer` does not call `Layer.build`. + + return true; + } + return false; } - bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape) + /// + /// Infers input shape of layer from SavedModel functions. + /// + /// + /// + /// + private Shape _infer_inputs(int layer_node_id, bool convert_to_shapes = false) { - if (obj.Built) - return true; + var call_fn_id = _search_for_child_node(layer_node_id, new string[] { "call_and_return_all_conditional_losses" }); + if(call_fn_id is null) + { + return null; + } + var concrete_functions = _proto.Nodes[call_fn_id.Value].Function.ConcreteFunctions; + if(concrete_functions is null) + { + return null; + } + var call_fn_name = concrete_functions[0]; + var call_fn_proto = _proto.ConcreteFunctions[call_fn_name]; + throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); + } + + private int? _search_for_child_node(int parent_id, IEnumerable path_to_child) + { + if(path_to_child is null || path_to_child.Count() == 0) + { + return parent_id; + } + + foreach(var child in _proto.Nodes[parent_id].Children) + { + if(child.LocalName == path_to_child.First()) + { + return _search_for_child_node(child.NodeId, path_to_child.Skip(1)); + } + } + return null; + } + + private bool _is_graph_network(Layer layer) + { + // TODO: deal with `RevivedLayer` + if(layer is Functional) + { + return (layer as Functional).IsGraphNetwork || layer is Sequential; + } return false; } + + private void _maybe_add_serialized_attributes(Layer layer, KerasMetaData metadata) + { + // TODO: deal with `RevivedLayer` + } + + /// + /// Creates edges for nodes that are recreated from config. + /// + /// + private Action _config_node_setter(Action setter) + { + void setattr_wrapper(object obj, object name, object value) + { + Debug.Assert(obj is Trackable); + Debug.Assert(name is string); + if((obj as Trackable)._lookup_dependency(name as string) is null) + { + setter(obj, name, value); + } + } + return setattr_wrapper; + } } } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs index c7b7e52f4..220eae4b4 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Saving.SavedModel; public partial class KerasSavedModelUtils { - public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures, + public static void save_model(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures, SaveOptions? options, bool save_traces = true) { if (!overwrite && File.Exists(filepath)) @@ -95,7 +95,7 @@ public static SavedMetadata generate_keras_metadata(IList saved_nodes BadConsumers = { } }, Identifier = layer.ObjectIdentifier, - Metadata = layer.TrackingMetadata + Metadata = layer.GetTrackingMetadata() }; metadata.Nodes.Add(saved_object); @@ -130,7 +130,7 @@ public static IDictionary wrap_layer_objects(Layer layer, IDi if (x is ResourceVariable or RefVariable) return (Trackable)x; else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); })); - var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => + var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.NonTrainableVariables.Select(x => { if (x is ResourceVariable or RefVariable) return (Trackable)x; else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs new file mode 100644 index 000000000..abb2012f8 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs @@ -0,0 +1,96 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using Tensorflow.Keras.Engine; +using Tensorflow.Train; +using ThirdParty.Tensorflow.Python.Keras.Protobuf; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + public class KerasLoadModelUtils + { + /// + /// Corresponding to keras/saving/save.py/load_model + /// + /// + /// + /// + /// + /// + public static Trackable load_model(string filepath, IDictionary? custom_objects = null, + bool compile = true, LoadOptions? options = null) + { + using (SharedObjectSavingScope.Enter()) + { + using (LoadContext.load_context(options)) + { + if (!File.Exists(filepath) && !Directory.Exists(filepath)) + { + throw new IOException($"No file or directory found at {filepath}."); + } + if (Directory.Exists(filepath)) + { + return load(filepath, compile, options); + } + else + { + throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed."); + } + } + } + } + + private static Trackable load(string path, bool compile = true, LoadOptions? options = null) + { + SavedMetadata metadata = new SavedMetadata(); + var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0]; + var object_graph_def = meta_graph_def.ObjectGraphDef; + string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH); + if (File.Exists(path_to_metadata_pb)) + { + metadata.MergeFrom(new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read)); + } + else + { + throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); + } + + if (metadata.Nodes is null || metadata.Nodes.Count == 0) + { + return Loader.load(path, options: options) as Model; + } + + var keras_loader = new KerasObjectLoader(metadata, object_graph_def); + keras_loader.load_layers(compile: compile); + + Dictionary)> nodes_to_load = new(); + nodes_to_load["root"] = (null, null); + foreach(var item in keras_loader.LoadedNodes) + { + nodes_to_load[keras_loader.get_path(item.Key)] = item.Value; + } + var loaded = Loader.load_partial(path, nodes_to_load, options); + + keras_loader.finalize_objects(); + // keras_loader.del_tracking(); + + var model = loaded["root"]; + + if(model is Model && compile) + { + // TODO(Rinne): implement it. + } + + if (!tf.Context.executing_eagerly()) + { + // TODO(Rinne): implement it. + } + + return model; + } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/load_context.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/load_context.cs new file mode 100644 index 000000000..11b1201d0 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/load_context.cs @@ -0,0 +1,69 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using Tensorflow.Training.Saving.SavedModel; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + // TODO: remove this class to common project. + public class ContextHandler: IDisposable + { + public Action DisposeCallBack { get; set; } + public void Dispose() + { + DisposeCallBack.Invoke(true); + } + } + public class LoadContext + { + private bool _entered_load_context; + private LoadOptions? _load_options; + private static ThreadLocal _load_context = new(); + private LoadContext() + { + _entered_load_context = false; + _load_options = null; + } + + public void set_load_options(LoadOptions load_options) + { + _load_options = load_options; + _entered_load_context = true; + } + + private void clear_load_options() + { + _load_options = null; + _entered_load_context = false; + } + + private LoadOptions? load_options() + { + return _load_options; + } + + public static ContextHandler load_context(LoadOptions? load_options) + { + if(_load_context.Value is null) + { + _load_context.Value = new LoadContext(); + } + _load_context.Value.set_load_options(load_options); + return new ContextHandler() + { + DisposeCallBack = _ => _load_context.Value.clear_load_options() + }; + } + + public static LoadOptions? get_load_option() + { + return _load_context.Value.load_options(); + } + + public static bool in_load_context() + { + return _load_context.Value._entered_load_context; + } + } +} diff --git a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs index 730a33e3e..03acce0ca 100644 --- a/src/TensorFlowNET.Keras/Utils/generic_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs @@ -19,15 +19,21 @@ limitations under the License. using System; using System.Collections; using System.Collections.Generic; +using System.Data; using System.Diagnostics; using System.Linq; +using System.Reflection; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; using Tensorflow.Keras.Saving; +using Tensorflow.Train; namespace Tensorflow.Keras.Utils { public class generic_utils { + private static readonly string _LAYER_UNDEFINED_CONFIG_KEY = "layer was saved without config"; /// /// This method does not have corresponding method in python. It's close to `serialize_keras_object`. /// @@ -51,6 +57,58 @@ public static JObject serialize_keras_object(IKerasConfigable instance) return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); } + public static Layer deserialize_keras_object(string class_name, JToken config) + { + var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args"); + var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public) + .Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0); + var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType); + var args = deserializationGenericMethod.Invoke(config, null); + var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null); + Debug.Assert(layer is Layer); + return layer as Layer; + } + + public static Layer deserialize_keras_object(string class_name, LayerArgs args) + { + var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null); + Debug.Assert(layer is Layer); + return layer as Layer; + } + + public static LayerArgs deserialize_layer_args(string class_name, JToken config) + { + var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args"); + var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public) + .Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0); + var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType); + var args = deserializationGenericMethod.Invoke(config, null); + Debug.Assert(args is LayerArgs); + return args as LayerArgs; + } + + public static ModelConfig deserialize_model_config(JToken json) + { + ModelConfig config = new ModelConfig(); + config.Name = json["name"].ToObject(); + config.Layers = new List(); + var layersToken = json["layers"]; + foreach (var token in layersToken) + { + var args = deserialize_layer_args(token["class_name"].ToObject(), token["config"]); + config.Layers.Add(new LayerConfig() + { + Config = args, + Name = token["name"].ToObject(), + ClassName = token["class_name"].ToObject(), + InboundNodes = token["inbound_nodes"].ToObject>() + }); + } + config.InputLayers = json["input_layers"].ToObject>(); + config.OutputLayers = json["output_layers"].ToObject>(); + return config; + } + public static string to_snake_case(string name) { return string.Concat(name.Select((x, i) => @@ -60,5 +118,15 @@ public static string to_snake_case(string name) x.ToString(); })).ToLower(); } + + /// + /// Determines whether config appears to be a valid layer config. + /// + /// + /// + public static bool validate_config(JObject config) + { + return !config.ContainsKey(_LAYER_UNDEFINED_CONFIG_KEY); + } } } diff --git a/src/TensorFlowNET.Keras/Utils/layer_utils.cs b/src/TensorFlowNET.Keras/Utils/layer_utils.cs index 3c38a6d1b..07d9f685e 100644 --- a/src/TensorFlowNET.Keras/Utils/layer_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/layer_utils.cs @@ -104,7 +104,7 @@ public static void print_summary(Model model, int line_length = -1, float[] posi } var trainable_count = count_params(model, model.TrainableVariables); - var non_trainable_count = count_params(model, model.non_trainable_variables); + var non_trainable_count = count_params(model, model.NonTrainableVariables); print($"Total params: {trainable_count + non_trainable_count}"); print($"Trainable params: {trainable_count}"); diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/bias0.npy b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/bias0.npy new file mode 100644 index 000000000..b5a8f8b32 Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/bias0.npy differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/fingerprint.pb b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/fingerprint.pb new file mode 100644 index 000000000..b62a57c3d Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/fingerprint.pb differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/keras_metadata.pb b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/keras_metadata.pb new file mode 100644 index 000000000..e1aab781a --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/keras_metadata.pb @@ -0,0 +1,9 @@ + +�$root"_tf_keras_network*�${"name": "model", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": false, "class_name": "Functional", "config": {"name": "model", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "name": "input_1", "inbound_nodes": []}, {"class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": "float32", "data_format": "channels_last"}, "name": "flatten", "inbound_nodes": [[["input_1", 0, 0, {}]]]}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 100, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense", "inbound_nodes": [[["flatten", 0, 0, {}]]]}, {"class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 10, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense_1", "inbound_nodes": [[["dense", 0, 0, {}]]]}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "name": "softmax", "inbound_nodes": [[["dense_1", 0, 0, {}]]]}], "input_layers": [["input_1", 0, 0]], "output_layers": [["softmax", 0, 0]]}, "shared_object_id": 9, "input_spec": [{"class_name": "InputSpec", "config": {"dtype": null, "shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "ndim": 4, "max_ndim": null, "min_ndim": null, "axes": {}}}], "build_input_shape": {"class_name": "TensorShape", "items": [null, 28, 28, 1]}, "is_graph_network": true, "full_save_spec": {"class_name": "__tuple__", "items": [[{"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 28, 28, 1]}, "float32", "input_1"]}], {}]}, "save_spec": {"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 28, 28, 1]}, "float32", "input_1"]}, "keras_version": "2.11.0", "backend": "tensorflow", "model_config": {"class_name": "Functional", "config": {"name": "model", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "name": "input_1", "inbound_nodes": [], "shared_object_id": 0}, {"class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": "float32", "data_format": "channels_last"}, "name": "flatten", "inbound_nodes": [[["input_1", 0, 0, {}]]], "shared_object_id": 1}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 100, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense", "inbound_nodes": [[["flatten", 0, 0, {}]]], "shared_object_id": 4}, {"class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 10, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 5}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 6}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense_1", "inbound_nodes": [[["dense", 0, 0, {}]]], "shared_object_id": 7}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "name": "softmax", "inbound_nodes": [[["dense_1", 0, 0, {}]]], "shared_object_id": 8}], "input_layers": [["input_1", 0, 0]], "output_layers": [["softmax", 0, 0]]}}}2 +� root.layer-0"_tf_keras_input_layer*�{"class_name": "InputLayer", "name": "input_1", "dtype": "float32", "sparse": false, "ragged": false, "batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}}2 +� root.layer-1"_tf_keras_layer*�{"name": "flatten", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": "float32", "data_format": "channels_last"}, "inbound_nodes": [[["input_1", 0, 0, {}]]], "shared_object_id": 1, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 1, "axes": {}}, "shared_object_id": 14}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 28, 28, 1]}}2 +�root.layer_with_weights-0"_tf_keras_layer*�{"name": "dense", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 100, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["flatten", 0, 0, {}]]], "shared_object_id": 4, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 784}}, "shared_object_id": 15}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 784]}}2 +�root.layer_with_weights-1"_tf_keras_layer*�{"name": "dense_1", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 10, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 5}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 6}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["dense", 0, 0, {}]]], "shared_object_id": 7, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 100}}, "shared_object_id": 16}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 100]}}2 +� root.layer-4"_tf_keras_layer*�{"name": "softmax", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "inbound_nodes": [[["dense_1", 0, 0, {}]]], "shared_object_id": 8, "build_input_shape": {"class_name": "TensorShape", "items": [null, 10]}}2 +�Troot.keras_api.metrics.0"_tf_keras_metric*�{"class_name": "Mean", "name": "loss", "dtype": "float32", "config": {"name": "loss", "dtype": "float32"}, "shared_object_id": 17}2 +�Uroot.keras_api.metrics.1"_tf_keras_metric*�{"class_name": "MeanMetricWrapper", "name": "sparse_categorical_accuracy", "dtype": "float32", "config": {"name": "sparse_categorical_accuracy", "dtype": "float32", "fn": "sparse_categorical_accuracy"}, "shared_object_id": 18}2 \ No newline at end of file diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/kernel1.npy b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/kernel1.npy new file mode 100644 index 000000000..dd70331cf Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/kernel1.npy differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/saved_model.pb b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/saved_model.pb new file mode 100644 index 000000000..771a58c62 Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/saved_model.pb differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/variables/variables.data-00000-of-00001 b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/variables/variables.data-00000-of-00001 new file mode 100644 index 000000000..0061f3865 Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/variables/variables.data-00000-of-00001 differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/variables/variables.index b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/variables/variables.index new file mode 100644 index 000000000..06ba4b293 Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/variables/variables.index differ diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs new file mode 100644 index 000000000..e778a5a4a --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs @@ -0,0 +1,68 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Metrics; +using Tensorflow; +using Tensorflow.Keras.Optimizers; +using static Tensorflow.KerasApi; +using Tensorflow.NumPy; +using static TensorFlowNET.Keras.UnitTest.SaveModel.SequentialModelSave; + +namespace TensorFlowNET.Keras.UnitTest.SaveModel; + +[TestClass] +public class SequentialModelLoad +{ + [TestMethod] + public void SimpleModelFromAutoCompile() + { + var model = keras.models.load_model(@"Assets/simple_model_from_auto_compile"); + model.summary(); + + model.compile(new Adam(0.0001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); + + // check the weights + var kernel1 = np.load(@"Assets/simple_model_from_auto_compile/kernel1.npy"); + var bias0 = np.load(@"Assets/simple_model_from_auto_compile/bias0.npy"); + + Assert.IsTrue(kernel1.Zip(model.TrainableWeights[2].numpy()).All(x => x.First == x.Second)); + Assert.IsTrue(bias0.Zip(model.TrainableWeights[1].numpy()).All(x => x.First == x.Second)); + + var data_loader = new MnistModelLoader(); + var num_epochs = 1; + var batch_size = 8; + + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = false, + ValidationSize = 50000, + }).Result; + + model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); + } + + [TestMethod] + public void AlexnetFromSequential() + { + new SequentialModelSave().AlexnetFromSequential(); + var model = keras.models.load_model(@"./alexnet_from_sequential"); + model.summary(); + + model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" }); + + var num_epochs = 1; + var batch_size = 8; + + var dataset = new RandomDataSet(new Shape(227, 227, 3), 16); + + model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs similarity index 94% rename from test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs rename to test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs index 269b9c058..fe9b8b71f 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs @@ -1,27 +1,21 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; -using Tensorflow.NumPy; -using System; using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; +using System.Diagnostics; using Tensorflow; -using static Tensorflow.Binding; -using static Tensorflow.KerasApi; using Tensorflow.Keras; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Layers; using Tensorflow.Keras.Losses; -using Tensorflow.Keras.Metrics; using Tensorflow.Keras.Optimizers; -using Tensorflow.Operations; -using System.Diagnostics; +using Tensorflow.NumPy; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; namespace TensorFlowNET.Keras.UnitTest.SaveModel; [TestClass] -public class SequentialModelTest +public class SequentialModelSave { [TestMethod] public void SimpleModelFromAutoCompile() @@ -63,6 +57,8 @@ public void SimpleModelFromSequential() keras.layers.Softmax(1) }); + model.summary(); + model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); var data_loader = new MnistModelLoader(); @@ -82,7 +78,7 @@ public void SimpleModelFromSequential() } [TestMethod] - public void AlexModelFromSequential() + public void AlexnetFromSequential() { Model model = KerasApi.keras.Sequential(new List() { @@ -116,7 +112,7 @@ public void AlexModelFromSequential() keras.layers.Softmax(1) }); - model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" }); + model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" }); var num_epochs = 1; var batch_size = 8; @@ -125,7 +121,7 @@ public void AlexModelFromSequential() model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); - model.save("./pb_alex_sequential", save_format: "tf"); + model.save("./alexnet_from_sequential", save_format: "tf"); // The saved model can be test with the following python code: #region alexnet_python_code diff --git a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj index c9020f7b4..bcd52c228 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj +++ b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj @@ -27,4 +27,28 @@ + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + +