diff --git a/src/TensorFlowNET.Core/APIs/tf.saved_model.cs b/src/TensorFlowNET.Core/APIs/tf.saved_model.cs new file mode 100644 index 000000000..ef6251ca8 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.saved_model.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; + +namespace Tensorflow +{ + public partial class tensorflow + { + public SavedModelAPI saved_model { get; } = new SavedModelAPI(); + } + + public class SavedModelAPI + { + public Trackable load(string export_dir, LoadOptions? options = null) + { + return Loader.load(export_dir, options); + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index ea4159694..3bce52ea5 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -8,6 +8,7 @@ using Tensorflow.Framework; using Tensorflow.Framework.Models; using Tensorflow.Functions; +using Tensorflow.NumPy; using Tensorflow.Operations; using Tensorflow.Util; using static Tensorflow.Binding; @@ -181,7 +182,7 @@ public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType const int _EAGER_CONST_THRESHOLD = 128; public Tensor capture(Tensor tensor, string name = null, Shape shape = null) { - if(tensor is EagerTensor) + if(tensor is EagerTensor or NDArray) { if (name == null) name = ops.uid().ToString(); diff --git a/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs b/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs index 58e7ef8c1..5458a5368 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs @@ -10,4 +10,5 @@ void apply_gradients((Tensor, IVariableV1) grads_and_vars, void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars, string name = null, bool experimental_aggregate_gradients = true); + IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null); } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 4261d72b7..311f2184f 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -216,10 +216,12 @@ public virtual T[] get_attr_list(string name) public virtual object get_attr(string name) { var buf = new Buffer(); - c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.Status); - tf.Status.Check(true); + Status status = new(); + c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); + status.Check(true); + var tf_buffer = c_api.TF_GetBuffer(buf); - var x = AttrValue.Parser.ParseFrom(buf.ToArray()); + var x = AttrValue.Parser.ParseFrom(tf_buffer.AsSpan()); var oneof_value = x.ValueCase; if (oneof_value == AttrValue.ValueOneofCase.None) diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 19dbd6edf..25bb88826 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -64,36 +64,68 @@ public static NDArray MakeNdarray(TensorProto tensor) var num_elements = shape.size; var tensor_dtype = tensor.Dtype.as_tf_dtype(); + T[] ExpandArrayToSize(IList src) + { + if(src.Count == 0) + { + return new T[0]; + } + var pad_count = num_elements - src.Count; + var pre = pad_count / 2; + var after = pad_count - pre; + var first_elem = src[0]; + var last_elem = src[src.Count - 1]; + T[] res = new T[num_elements]; + for(long i = 0; i < num_elements; i++) + { + if (i < pre) res[i] = first_elem; + else if (i >= num_elements - after) res[i] = last_elem; + else res[i] = src[(int)(i - pre)]; + } + return res; + } + if (shape.ndim > 0 && tensor.TensorContent.Length > 0) { return np.frombuffer(tensor.TensorContent.ToByteArray(), shape, tensor_dtype); } - else if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16) + NDArray values; + if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16) { - return np.array(tensor.HalfVal.ToArray()).reshape(shape); + values = np.array(ExpandArrayToSize(tensor.HalfVal)); } else if (tensor.Dtype == DataType.DtFloat) { - return np.array(tensor.FloatVal.ToArray()).reshape(shape); + values = np.array(ExpandArrayToSize(tensor.FloatVal)); } else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype)) { - return np.array(tensor.IntVal.ToArray()).reshape(shape); + values = np.array(ExpandArrayToSize(tensor.IntVal)); } else if (new DataType[] { DataType.DtInt64 }.Contains(tensor.Dtype)) { - return np.array(tensor.Int64Val.ToArray()).reshape(shape); + values = np.array(ExpandArrayToSize(tensor.Int64Val)); } else if (new DataType[] { DataType.DtUint64 }.Contains(tensor.Dtype)) { - return np.array(tensor.Uint64Val.ToArray()).reshape(shape); + values = np.array(ExpandArrayToSize(tensor.Uint64Val)); } else if (tensor.Dtype == DataType.DtBool) { - return np.array(tensor.BoolVal.ToArray()).reshape(shape); + values = np.array(ExpandArrayToSize(tensor.BoolVal)); + } + else + { + throw new TypeError($"Unsupported tensor type: {tensor.Dtype}. See " + + $"https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes."); + } + + if(values.size == 0) + { + return np.zeros(shape, tensor_dtype); } - throw new NotImplementedException("MakeNdarray"); + return values.reshape(shape); } private static readonly TF_DataType[] quantized_types = new TF_DataType[] diff --git a/src/TensorFlowNET.Core/Trackables/TrackableConstant.cs b/src/TensorFlowNET.Core/Trackables/TrackableConstant.cs index 6de8274a1..d65446f3d 100644 --- a/src/TensorFlowNET.Core/Trackables/TrackableConstant.cs +++ b/src/TensorFlowNET.Core/Trackables/TrackableConstant.cs @@ -1,5 +1,6 @@ using Google.Protobuf.Collections; using Tensorflow.Train; +using static Tensorflow.Binding; namespace Tensorflow.Trackables; @@ -11,12 +12,23 @@ public TrackableConstant(Tensor constant) _constant = constant; } - public static (Trackable, Action) deserialize_from_proto(SavedObject object_proto, + public static (Tensor, Action) deserialize_from_proto(SavedObject object_proto, Dictionary> operation_attributes) { var tensor_proto = operation_attributes[object_proto.Constant.Operation]["value"].Tensor; var ndarray = tensor_util.MakeNdarray(tensor_proto); - var imported_constant = constant_op.constant(ndarray); - return (new TrackableConstant(imported_constant), null); + Tensor imported_constant; + if (tensor_proto.Dtype == DataType.DtString) + { + imported_constant = tf_with(ops.device("CPU"), _ => + { + return constant_op.constant(ndarray); + }); + } + else + { + imported_constant = constant_op.constant(ndarray); + } + return (imported_constant, null); } } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs index 5bb7238e7..ab6adc30f 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs @@ -46,4 +46,9 @@ public static (Trackable, Action) deserialize(SavedUserO return (null, null); } } + + public static void RegisterRevivedTypeCreator(string identifier, ITrackableWrapper obj) + { + _registered_revived_creator[identifier] = obj; + } } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs index 5752d7284..b7d987e71 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs @@ -137,7 +137,7 @@ public List get_concrete_resource_initializers() /// public List dependency_sorted_node_ids() { - Dictionary> dependency_map = new(); + Dictionary> dependency_map = new(); foreach (var node in _nodes) { var node_id = _node_ids[node]; diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs index d6986af3d..af9fbeda5 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs @@ -116,17 +116,23 @@ public static Dictionary load_function_def_library(Fun } Dictionary loaded_gradients = new(); - foreach (var fdef in _sort_function_defs(library, function_deps)) + // Debug(Rinne) + var temp = _sort_function_defs(library, function_deps); + int i = 0; + foreach (var fdef in temp) { + i++; var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types); object structured_input_signature = null; object structured_outputs = null; if (saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) { - var proto = saved_object_graph.ConcreteFunctions[orig_name]; - structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature); - structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature); + // TODO(Rinne): deal with structured_input_signature and structured_outputs. + + //var proto = saved_object_graph.ConcreteFunctions[orig_name]; + //structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature); + //structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature); } graph.as_default(); @@ -234,27 +240,41 @@ private static Func _gen_gradient_func(ConcreteFu private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary renamed_functions, Dictionary loaded_gradients) { - foreach(var op in func_graph.get_operations()) + if(loaded_gradients is null || loaded_gradients.Count == 0) { - if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") - { - var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name]; - op.op._gradient_function = function._get_gradient_function(); - } - string gradient_op_type = null; - try - { - gradient_op_type = op.op.get_attr("_gradient_op_type") as string; - } - catch(InvalidArgumentError) + foreach (var op in func_graph.get_operations()) { - continue; + if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") + { + var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name]; + op.op._gradient_function = function._get_gradient_function(); + } } - if (loaded_gradients.ContainsKey(gradient_op_type)) + } + else + { + foreach (var op in func_graph.get_operations()) { - var grad_fn = loaded_gradients[gradient_op_type]; - grad_fn.NumPositionArgs = op.op.inputs.Length; - grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name); + if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") + { + var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name]; + op.op._gradient_function = function._get_gradient_function(); + } + string gradient_op_type = null; + try + { + gradient_op_type = op.op.get_attr("_gradient_op_type") as string; + } + catch (InvalidArgumentError) + { + continue; + } + if (loaded_gradients.ContainsKey(gradient_op_type)) + { + var grad_fn = loaded_gradients[gradient_op_type]; + grad_fn.NumPositionArgs = op.op.inputs.Length; + grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name); + } } } } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs index cad32c59d..ae7e2cf5a 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs @@ -15,6 +15,7 @@ using Tensorflow.Training.Saving.SavedModel; using Tensorflow.Trackables; using OneOf; +using Tensorflow.Keras.Engine; namespace Tensorflow { @@ -34,7 +35,7 @@ public partial class Loader private List? _filtered_nodes; private List _ordered_node_ids; private Dictionary)> _loaded_nodes; - private List _nodes; + private List _nodes; private Dictionary> _node_setters; private Dictionary _concrete_functions; private HashSet _restored_concrete_functions; @@ -213,7 +214,13 @@ private List _generate_ordered_node_ids() continue; } var proto = _proto.Nodes[node_id]; - foreach(var dep in _get_node_dependencies(proto).Values.Distinct()) + if(node_id == 10522) + { + // Debug(Rinne) + Console.WriteLine(); + } + var temp = _get_node_dependencies(proto); + foreach (var dep in _get_node_dependencies(proto).Values.Distinct()) { deps.Add(dep); if(_filtered_nodes is not null && !_filtered_nodes.Contains(dep)) @@ -232,7 +239,7 @@ private List _generate_ordered_node_ids() // The optimizer and original variable must be created before the slot // variable, since the slot variable is generated using the Optimizer's // add_slot API. - var slot_deps = dependency_map[slot_variable_node_id]; + var slot_deps = dependency_map.SetDefault(slot_variable_node_id, new List()); slot_deps.Add(node_id); slot_deps.Add(slot_variable_proto.OriginalVariableNodeId); @@ -245,7 +252,12 @@ private List _generate_ordered_node_ids() } try { - return TrackableUtils.order_by_dependency(dependency_map.ToDictionary(x => x.Key, x => x.Value as IEnumerable)); + int total = 0; + foreach(var v in dependency_map.Values) + { + total += v.Count; + } + return TrackableUtils.order_by_dependency(dependency_map); } catch (TrackableUtils.CyclicDependencyError ex) { @@ -339,9 +351,20 @@ private void _load_checkpoint_save_and_restore_functions() var saveable_object_proto = item.Value; var save_fn_id = saveable_object_proto.SaveFunction; var restore_fn_id = saveable_object_proto.RestoreFunction; - saveable_fn_by_name[name] = (get(save_fn_id), get(restore_fn_id)); + saveable_fn_by_name[name] = ((Trackable)get(save_fn_id), (Trackable)get(restore_fn_id)); + } + var saveable_objects = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null); + if (saveable_objects is not null && saveable_objects.Count > 0) + { + if(node is Trackable trackable) + { + trackable.SelfSaveableObjectFactories = saveable_objects; + } + else + { + throw new TypeError(); + } } - node.SelfSaveableObjectFactories = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null); } } } @@ -379,12 +402,12 @@ private void _load_nodes() { // Use the public Optimizer interface when creating slot variables. var (optimizer_node_id, slot_variable_proto) = slot_variable_node_ids[node_id]; - var optimizer_object = nodes[optimizer_node_id]; + var optimizer_object = nodes[optimizer_node_id] as IOptimizer; var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId]; - // TODO(Rinne): implement it. - throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." + - " Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); + var slot_variable = optimizer_object.add_slot(optimizer_variable as IVariableV1, slot_variable_proto.SlotName); + nodes[slot_variable_proto.SlotVariableNodeId] = slot_variable as Trackable; + node_setters[slot_variable_proto.SlotVariableNodeId] = setattr; } else { @@ -398,7 +421,7 @@ private void _load_nodes() { nodes[0] = _recreate_base_user_object().Item1; } - _nodes = new List(); + _nodes = new List(); for(int i = 0; i < _proto.Nodes.Count; i++) { _nodes.Add(nodes[i]); @@ -412,7 +435,7 @@ private void _load_nodes() private void _restore_checkpoint() { var variables_path = SavedModelUtils.get_variables_path(_export_dir); - var saver = new TrackableSaver(new ObjectGraphView(get(0))); + var saver = new TrackableSaver(new ObjectGraphView((Trackable)get(0))); tf_with(ops.device("CPU"), _ => { saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); @@ -467,7 +490,7 @@ private void _load_edges() } } - private void _setup_function_captures(string concrete_function_name, IDictionary, Trackable> nodes) + private void _setup_function_captures(string concrete_function_name, IDictionary, object> nodes) { if (_restored_concrete_functions.Contains(concrete_function_name)) { @@ -485,12 +508,12 @@ private void _setup_remaining_functions() // TODO: implement it with concrete functions. } - public Trackable get(int node_id) + public object get(int node_id) { return _nodes[node_id]; } - public Trackable get(string node_id) + public object get(string node_id) { return get(_node_path_to_id[node_id]); } @@ -512,9 +535,9 @@ private void _add_object_graph_edges(SavedObject proto, int node_id) } } - private (Dictionary, Dictionary>) _initialize_loaded_nodes() + private (Dictionary, Dictionary>) _initialize_loaded_nodes() { - Dictionary nodes = new(); + Dictionary nodes = new(); Dictionary> node_setters = new(); foreach(var item in _loaded_nodes) { @@ -534,10 +557,10 @@ private void _add_object_graph_edges(SavedObject proto, int node_id) } } - private (Trackable, Action) _recreate(SavedObject proto, int node_id, IDictionary nodes) + private (object, Action) _recreate(SavedObject proto, int node_id, IDictionary nodes) { // skip the registered classes. - Dictionary, Trackable> dependencies = new(); + Dictionary, object> dependencies = new(); foreach(var item in _get_node_dependencies(proto)) { dependencies[item.Key] = nodes[item.Value]; @@ -558,7 +581,7 @@ private void _add_object_graph_edges(SavedObject proto, int node_id) /// /// /// - private (Trackable, Action) _recreate_default(SavedObject proto, int node_id, IDictionary, Trackable> dependencies) + private (Trackable, Action) _recreate_default(SavedObject proto, int node_id, IDictionary, object> dependencies) { return proto.KindCase switch { @@ -626,7 +649,7 @@ private void _add_object_graph_edges(SavedObject proto, int node_id) } private (Function, Action) _recreate_function(SavedFunction proto, - IDictionary, Trackable> dependencies) + IDictionary, object> dependencies) { var fn = function_deserialization.recreate_function(proto, _concrete_functions); foreach (var name in proto.ConcreteFunctions) @@ -637,7 +660,7 @@ private void _add_object_graph_edges(SavedObject proto, int node_id) } private (ConcreteFunction, Action) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, - IDictionary, Trackable> dependencies) + IDictionary, object> dependencies) { var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions); _setup_function_captures(proto.ConcreteFunctionName, dependencies); diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs index a92cb5509..d1c0170c8 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs @@ -78,7 +78,7 @@ public static IDictionary load_partial(string export_dir, IDi tf_with(ops.init_scope(), x => { loader = new Loader(object_graph_proto, saved_model_proto, export_dir, ckpt_options, options, filters); - root = loader.get(0); + root = (Trackable)loader.get(0); // skip the assignment of `graph_debug_info`. }); // skip the assignment of `tensorflow_version` @@ -99,7 +99,7 @@ public static IDictionary load_partial(string export_dir, IDi } if(filters != null && filters.Count > 0) { - return filters.Keys.ToDictionary(x => x, x => loader.get(x)); + return filters.Keys.ToDictionary(x => x, x => (Trackable)loader.get(x)); } else { diff --git a/src/TensorFlowNET.Core/Training/TrackableUtils.cs b/src/TensorFlowNET.Core/Training/TrackableUtils.cs index 05c513a83..89bb614d2 100644 --- a/src/TensorFlowNET.Core/Training/TrackableUtils.cs +++ b/src/TensorFlowNET.Core/Training/TrackableUtils.cs @@ -52,7 +52,7 @@ public static string checkpoint_key(string object_path, string local_name) /// /// /// - public static List order_by_dependency(IDictionary> dependency_map) + public static List order_by_dependency(IDictionary> dependency_map) { Dictionary> reverse_dependency_map = new(); foreach (var pair in dependency_map) @@ -102,7 +102,7 @@ public static List order_by_dependency(IDictionary> d edges.Remove(x); if (edges.Count == 0) { - to_visit.Enqueue(dep); + to_visit.Enqueue(dep); if (!reverse_dependency_map.Remove(dep)) { throw new KeyError($"Cannot find the key {dep} in reverse_dependency_map"); diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 64728020c..64fe0ec84 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -333,5 +333,23 @@ public Tensor read_value_no_copy() }); return array_ops.identity(value); } + + //public static Tensor operator +(BaseResourceVariable x, int y) => x.value() + y; + //public static Tensor operator +(BaseResourceVariable x, float y) => x.value() + y; + //public static Tensor operator +(BaseResourceVariable x, double y) => x.value() + y; + //public static Tensor operator +(BaseResourceVariable x, BaseResourceVariable y) => x.value() + y.value(); + //public static Tensor operator -(BaseResourceVariable x, int y) => x.value() - y; + //public static Tensor operator -(BaseResourceVariable x, float y) => x.value() - y; + //public static Tensor operator -(BaseResourceVariable x, double y) => x.value() - y; + //public static Tensor operator -(BaseResourceVariable x, Tensor y) => x.value() - y; + //public static Tensor operator -(BaseResourceVariable x, BaseResourceVariable y) => x.value() - y.value(); + + //public static Tensor operator *(BaseResourceVariable x, BaseResourceVariable y) => x.value() * y.value(); + //public static Tensor operator *(BaseResourceVariable x, Tensor y) => x.value() * y; + //public static Tensor operator *(BaseResourceVariable x, NDArray y) => x.value() * y; + + //public static Tensor operator <(BaseResourceVariable x, Tensor y) => x.value() < y; + + //public static Tensor operator >(BaseResourceVariable x, Tensor y) => x.value() > y; } } diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs index 29d6106b5..2737a2191 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs @@ -1,19 +1,6 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - +using System; +using System.Collections.Generic; +using System.Text; using Tensorflow.NumPy; namespace Tensorflow diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index d13990a09..80403ad6a 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -169,6 +169,12 @@ public void set_learning_phase(bool value) _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0); } + public void set_value(IVariableV1 x, object value) + { + // TODO(Rinne): check the implementation. + x.assign(value); + } + public void batch_set_value(List<(IVariableV1, NDArray)> tuples) { if (ops.executing_eagerly_outside_functions()) diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs index 7c6a692ef..159564aac 100644 --- a/src/TensorFlowNET.Keras/KerasInterface.cs +++ b/src/TensorFlowNET.Keras/KerasInterface.cs @@ -36,6 +36,11 @@ public static KerasInterface Instance } } + static KerasInterface() + { + RevivedTypes.RegisterRevivedTypeCreator("optimizer", new RestoredOptimizer()); + } + public KerasDataset datasets { get; } = new KerasDataset(); public IInitializersApi initializers { get; } = new InitializersApi(); public Regularizers regularizers { get; } = new Regularizers(); diff --git a/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs b/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs index e49d757a0..44c163bc8 100644 --- a/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs +++ b/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs @@ -14,11 +14,11 @@ public class OptimizerV2 : Trackable, IOptimizer protected bool _hypers_created; protected virtual string _name { get; } - IVariableV1 _iterations; + protected IVariableV1 _iterations; protected ResourceVariable iterations => _iterations as ResourceVariable; List _weights; - Dictionary _hyper; - Dictionary _hyper_variables; + protected Dictionary _hyper; + protected Dictionary _hyper_variables; protected bool _momentum; protected float _initial_decay = 0.0f; protected bool _use_locking = true; @@ -224,7 +224,7 @@ protected virtual void _create_slots(IVariableV1[] var_list) } } - protected IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null) + public IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null) { if (initializer == null) initializer = tf.zeros_initializer; diff --git a/src/TensorFlowNET.Keras/Optimizers/RestoredOptimizer.cs b/src/TensorFlowNET.Keras/Optimizers/RestoredOptimizer.cs new file mode 100644 index 000000000..e5cfd2daa --- /dev/null +++ b/src/TensorFlowNET.Keras/Optimizers/RestoredOptimizer.cs @@ -0,0 +1,63 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Saving; +using Tensorflow.Train; +using Tensorflow.Training; + +namespace Tensorflow.Keras.Optimizers +{ + public class RestoredOptimizer: OptimizerV2, ITrackableWrapper, IKerasConfig + { + public String Identifier { get; } = "optimizer"; + public int Version { get; } = 2; + public int MinConsumerVersion { get; } = 1; + public int MinProducerVersion { get; } = 1; + public RestoredOptimizer(): base(new ArgsDefinition.OptimizerV2Args() { Name = "RestoredOptimizer" }) + { + _hypers_created = true; + } + + public IKerasConfig get_config() + { + throw new NotImplementedException("Restoring functional Optimizers from SavedModels is not currently " + + "supported. Please file a feature request if this limitation bothers you."); + } + + public void SetValue(object name, object value) + { + if(name is not String str) + { + throw new TypeError($"The name of value to set must be string, but got {name.GetType()}"); + } + if(value is Trackable trackable) + { + _track_trackable(trackable, str, overwrite: true); + } + if(value is IVariableV1 resource_variable) + { + if (!_hyper_variables.ContainsKey(str)) + { + _hyper_variables[str] = resource_variable; + } + else + { + keras.backend.set_value(resource_variable, value); + } + } + else if (value is float f) + { + _hyper[str] = f; + } + else + { + throw new NotImplementedException(); + } + } + + public Trackable FromProto(SavedUserObject proto) + { + return new RestoredOptimizer(); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs index 806c4ece8..7a5aee0f4 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs @@ -2,6 +2,7 @@ using System; using System.Linq; using Tensorflow; +using Tensorflow.Keras.Engine; using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.UnitTest.Helpers; using Tensorflow.NumPy; @@ -103,4 +104,13 @@ public void VGG19() classify_model.fit(x, y, batch_size: 4); } + + [Ignore] + [TestMethod] + public void TestModelBeforeTF2_5() + { + var a = keras.layers; + var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Model; + model.summary(); + } }