From 51b5f17c9a17397d61d1dc7df460517940e1107b Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Wed, 21 Jun 2023 21:41:06 +0800 Subject: [PATCH 1/3] fix: RNN training error on linux. --- src/TensorFlowNET.Core/APIs/c_api.cs | 14 ++++------- .../APIs/c_api.customize.cs | 2 +- src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs | 25 +++++++++++++++++++ src/TensorFlowNET.Core/Graphs/FuncGraph.cs | 12 ++++----- src/TensorFlowNET.Core/Operations/list_ops.cs | 2 +- src/TensorFlowNET.Core/Operations/while_v2.cs | 9 +++---- src/TensorFlowNET.Core/ops.cs | 3 ++- 7 files changed, 44 insertions(+), 23 deletions(-) create mode 100644 src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs index 6049c95cc..d4744e789 100644 --- a/src/TensorFlowNET.Core/APIs/c_api.cs +++ b/src/TensorFlowNET.Core/APIs/c_api.cs @@ -51,17 +51,13 @@ public static string StringPiece(IntPtr handle) return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); } - public unsafe static byte[] ByteStringPiece(IntPtr handle) + public unsafe static byte[] ByteStringPiece(Buffer? handle) { - byte* str_data = (byte*)handle.ToPointer(); - List bytes = new List(); - byte current = 255; - while (current != ((byte)'\0')) - { - current = *(str_data++); - bytes.Add(current); + if(handle is null){ + return new byte[0]; } - return bytes.Take(bytes.Count - 1).ToArray(); + var data = handle.ToArray(); + return data; } [UnmanagedFunctionPointer(CallingConvention.Winapi)] diff --git a/src/TensorFlowNET.Core/APIs/c_api.customize.cs b/src/TensorFlowNET.Core/APIs/c_api.customize.cs index d2aab9ac0..510e52eb7 100644 --- a/src/TensorFlowNET.Core/APIs/c_api.customize.cs +++ b/src/TensorFlowNET.Core/APIs/c_api.customize.cs @@ -10,7 +10,7 @@ public partial class c_api [DllImport(TensorFlowLibName)] public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); [DllImport(TensorFlowLibName)] - public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); + public static extern SafeBufferHandle TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); [DllImport(TensorFlowLibName)] public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); } diff --git a/src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs b/src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs new file mode 100644 index 000000000..2c20cfe9b --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs @@ -0,0 +1,25 @@ +using Tensorflow; + +internal static class GraphOnlyOps +{ + /// + /// Graph-only version of tf.compat.v1.placeholder(), for internal use only. + /// + /// + /// + /// + /// + internal static Tensor graph_placeholder(TF_DataType dtype, Shape shape, string? name = null) + { + var dtype_value = new AttrValue() { Type = dtype.as_datatype_enum() }; + var shape_value = new AttrValue() { Shape = shape.as_proto() }; + var g = ops.get_default_graph(); + Dictionary attrs = new(); + attrs["dtype"] = dtype_value; + attrs["shape"] = shape_value; + var op = g.create_op("Placeholder", new Tensor[0], new TF_DataType[] { dtype }, + new TF_DataType[0], attrs: attrs, name: name); + var result = op.outputs[0]; + return result; + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index ba7d7068e..6f7fa9c5f 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -544,12 +544,12 @@ private static object _get_defun_input(object arg, string name) Tensor placeholder; try { - placeholder = tf.placeholder(tensor.dtype, tensor.shape, name); + placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape, name); } - catch (ValueError) + catch (ValueError ex) { - // TODO(Rinne): Add warning here. - placeholder = tf.placeholder(tensor.dtype, tensor.shape); + tf.Logger.Warning(ex.ToString()); + placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape); } handle_data_util.copy_handle_data(tensor, placeholder); if (name is not null) @@ -575,12 +575,12 @@ private static object _get_defun_input(object arg, string name) Tensor placeholder; try { - placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name); + placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape, requested_name); } catch (ValueError) { // TODO(Rinne): Add warning here. - placeholder = tf.placeholder(spec.dtype, spec.shape); + placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape); } if (name is not null) { diff --git a/src/TensorFlowNET.Core/Operations/list_ops.cs b/src/TensorFlowNET.Core/Operations/list_ops.cs index c5e83ee41..3791a2c19 100644 --- a/src/TensorFlowNET.Core/Operations/list_ops.cs +++ b/src/TensorFlowNET.Core/Operations/list_ops.cs @@ -31,7 +31,7 @@ private static Tensor _build_element_shape(Shape? shape) } else { - return ops.convert_to_tensor(shape); + return ops.convert_to_tensor(shape, dtype: dtypes.int32); } } diff --git a/src/TensorFlowNET.Core/Operations/while_v2.cs b/src/TensorFlowNET.Core/Operations/while_v2.cs index 3f324f872..aae15b77d 100644 --- a/src/TensorFlowNET.Core/Operations/while_v2.cs +++ b/src/TensorFlowNET.Core/Operations/while_v2.cs @@ -38,9 +38,9 @@ public static Tensor[] while_loop(Func cond, int len_orig_loop_vars = orig_loop_vars.Length; loop_vars = _tensor_array_to_flow(loop_vars); - loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x, TF_DataType.DtInvalid, null), loop_vars).ToTensors(); + loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x), loop_vars).ToTensors(); - var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), _tensor_array_to_flow(loop_vars)); + var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), loop_vars); var flat_shape_invariants = Nest.Flatten(loop_vars_signature).Select(x => x.shape).ToArray(); @@ -379,10 +379,9 @@ private static string _build_cond_placeholders_name_prefix(FuncGraph cond_graph) return cond_graph.unique_name(cond_graph.Name + "___redundant_placeholder"); } - private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype, - string name) + private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value) { - return ops.convert_to_tensor(value, dtype, name, false); + return ops.convert_to_tensor(value, as_ref: false); } private static Tensor _build_maximum_iterations_loop_var(int maximum_iterations = -1) diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index fb9bccf31..a962e6d87 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -576,7 +576,8 @@ public static bool inside_function() public static HandleData get_resource_handle_data(Tensor graph_op) { var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); - return HandleData.Parser.ParseFrom(c_api.ByteStringPiece(handle_data)); + var handle_str = c_api.ByteStringPiece(handle_data.DangerousGetHandle() == IntPtr.Zero ? null : new Buffer(handle_data)); + return HandleData.Parser.ParseFrom(handle_str); } public static void dismantle_graph(Graph graph) From 69b3bce3309d62b26d91614a1e2430ff0e5b183c Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Thu, 22 Jun 2023 02:07:10 +0800 Subject: [PATCH 2/3] test: update the redist version of test. --- .../Tensorflow.UnitTest.RedistHolder.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj b/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj index 878077582..1ca387dbb 100644 --- a/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj +++ b/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj @@ -5,7 +5,7 @@ - + From ae8fe840e457b0b34d04fc0cafdb31d89b7a9d4d Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Thu, 22 Jun 2023 09:21:18 +0800 Subject: [PATCH 3/3] fix: resolve conflict. --- src/TensorFlowNET.Core/APIs/c_api.cs | 4 +++- src/TensorFlowNET.Core/ops.cs | 10 ++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs index 559176a54..a91b86827 100644 --- a/src/TensorFlowNET.Core/APIs/c_api.cs +++ b/src/TensorFlowNET.Core/APIs/c_api.cs @@ -53,8 +53,10 @@ public static string StringPiece(IntPtr handle) public unsafe static byte[] ByteStringPiece(Buffer? handle) { - if(handle is null){ + if (handle is null) + { return new byte[0]; + } var data = handle.ToArray(); return data; } diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index a962e6d87..7bd78a79f 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -576,8 +576,14 @@ public static bool inside_function() public static HandleData get_resource_handle_data(Tensor graph_op) { var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); - var handle_str = c_api.ByteStringPiece(handle_data.DangerousGetHandle() == IntPtr.Zero ? null : new Buffer(handle_data)); - return HandleData.Parser.ParseFrom(handle_str); + try{ + var handle_str = c_api.ByteStringPiece(handle_data.DangerousGetHandle() == IntPtr.Zero ? null : new Buffer(handle_data)); + return HandleData.Parser.ParseFrom(handle_str); + } + catch(Exception){ + var handle_str = c_api.ByteStringPieceFromNativeString(handle_data.DangerousGetHandle()); + return HandleData.Parser.ParseFrom(handle_str); + } } public static void dismantle_graph(Graph graph)