diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs index 63bdfd27d..a91b86827 100644 --- a/src/TensorFlowNET.Core/APIs/c_api.cs +++ b/src/TensorFlowNET.Core/APIs/c_api.cs @@ -51,7 +51,17 @@ public static string StringPiece(IntPtr handle) return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); } - public unsafe static byte[] ByteStringPiece(IntPtr handle) + public unsafe static byte[] ByteStringPiece(Buffer? handle) + { + if (handle is null) + { + return new byte[0]; + } + var data = handle.ToArray(); + return data; + } + + public unsafe static byte[] ByteStringPieceFromNativeString(IntPtr handle) { if (handle == IntPtr.Zero) { @@ -66,7 +76,8 @@ public unsafe static byte[] ByteStringPiece(IntPtr handle) current = *(str_data++); bytes.Add(current); } - return bytes.Take(bytes.Count - 1).ToArray(); + var data = bytes.ToArray(); + return data; } [UnmanagedFunctionPointer(CallingConvention.Winapi)] diff --git a/src/TensorFlowNET.Core/APIs/c_api.customize.cs b/src/TensorFlowNET.Core/APIs/c_api.customize.cs index d2aab9ac0..510e52eb7 100644 --- a/src/TensorFlowNET.Core/APIs/c_api.customize.cs +++ b/src/TensorFlowNET.Core/APIs/c_api.customize.cs @@ -10,7 +10,7 @@ public partial class c_api [DllImport(TensorFlowLibName)] public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); [DllImport(TensorFlowLibName)] - public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); + public static extern SafeBufferHandle TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); [DllImport(TensorFlowLibName)] public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); } diff --git a/src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs b/src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs new file mode 100644 index 000000000..2c20cfe9b --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs @@ -0,0 +1,25 @@ +using Tensorflow; + +internal static class GraphOnlyOps +{ + /// + /// Graph-only version of tf.compat.v1.placeholder(), for internal use only. + /// + /// + /// + /// + /// + internal static Tensor graph_placeholder(TF_DataType dtype, Shape shape, string? name = null) + { + var dtype_value = new AttrValue() { Type = dtype.as_datatype_enum() }; + var shape_value = new AttrValue() { Shape = shape.as_proto() }; + var g = ops.get_default_graph(); + Dictionary attrs = new(); + attrs["dtype"] = dtype_value; + attrs["shape"] = shape_value; + var op = g.create_op("Placeholder", new Tensor[0], new TF_DataType[] { dtype }, + new TF_DataType[0], attrs: attrs, name: name); + var result = op.outputs[0]; + return result; + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index ba7d7068e..6f7fa9c5f 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -544,12 +544,12 @@ private static object _get_defun_input(object arg, string name) Tensor placeholder; try { - placeholder = tf.placeholder(tensor.dtype, tensor.shape, name); + placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape, name); } - catch (ValueError) + catch (ValueError ex) { - // TODO(Rinne): Add warning here. - placeholder = tf.placeholder(tensor.dtype, tensor.shape); + tf.Logger.Warning(ex.ToString()); + placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape); } handle_data_util.copy_handle_data(tensor, placeholder); if (name is not null) @@ -575,12 +575,12 @@ private static object _get_defun_input(object arg, string name) Tensor placeholder; try { - placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name); + placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape, requested_name); } catch (ValueError) { // TODO(Rinne): Add warning here. - placeholder = tf.placeholder(spec.dtype, spec.shape); + placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape); } if (name is not null) { diff --git a/src/TensorFlowNET.Core/Operations/list_ops.cs b/src/TensorFlowNET.Core/Operations/list_ops.cs index c5e83ee41..3791a2c19 100644 --- a/src/TensorFlowNET.Core/Operations/list_ops.cs +++ b/src/TensorFlowNET.Core/Operations/list_ops.cs @@ -31,7 +31,7 @@ private static Tensor _build_element_shape(Shape? shape) } else { - return ops.convert_to_tensor(shape); + return ops.convert_to_tensor(shape, dtype: dtypes.int32); } } diff --git a/src/TensorFlowNET.Core/Operations/while_v2.cs b/src/TensorFlowNET.Core/Operations/while_v2.cs index 3f324f872..aae15b77d 100644 --- a/src/TensorFlowNET.Core/Operations/while_v2.cs +++ b/src/TensorFlowNET.Core/Operations/while_v2.cs @@ -38,9 +38,9 @@ public static Tensor[] while_loop(Func cond, int len_orig_loop_vars = orig_loop_vars.Length; loop_vars = _tensor_array_to_flow(loop_vars); - loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x, TF_DataType.DtInvalid, null), loop_vars).ToTensors(); + loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x), loop_vars).ToTensors(); - var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), _tensor_array_to_flow(loop_vars)); + var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), loop_vars); var flat_shape_invariants = Nest.Flatten(loop_vars_signature).Select(x => x.shape).ToArray(); @@ -379,10 +379,9 @@ private static string _build_cond_placeholders_name_prefix(FuncGraph cond_graph) return cond_graph.unique_name(cond_graph.Name + "___redundant_placeholder"); } - private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype, - string name) + private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value) { - return ops.convert_to_tensor(value, dtype, name, false); + return ops.convert_to_tensor(value, as_ref: false); } private static Tensor _build_maximum_iterations_loop_var(int maximum_iterations = -1) diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index fb9bccf31..7bd78a79f 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -576,7 +576,14 @@ public static bool inside_function() public static HandleData get_resource_handle_data(Tensor graph_op) { var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); - return HandleData.Parser.ParseFrom(c_api.ByteStringPiece(handle_data)); + try{ + var handle_str = c_api.ByteStringPiece(handle_data.DangerousGetHandle() == IntPtr.Zero ? null : new Buffer(handle_data)); + return HandleData.Parser.ParseFrom(handle_str); + } + catch(Exception){ + var handle_str = c_api.ByteStringPieceFromNativeString(handle_data.DangerousGetHandle()); + return HandleData.Parser.ParseFrom(handle_str); + } } public static void dismantle_graph(Graph graph) diff --git a/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj b/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj index 878077582..1ca387dbb 100644 --- a/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj +++ b/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj @@ -5,7 +5,7 @@ - +