diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs
index 63bdfd27d..a91b86827 100644
--- a/src/TensorFlowNET.Core/APIs/c_api.cs
+++ b/src/TensorFlowNET.Core/APIs/c_api.cs
@@ -51,7 +51,17 @@ public static string StringPiece(IntPtr handle)
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle);
}
- public unsafe static byte[] ByteStringPiece(IntPtr handle)
+ public unsafe static byte[] ByteStringPiece(Buffer? handle)
+ {
+ if (handle is null)
+ {
+ return new byte[0];
+ }
+ var data = handle.ToArray();
+ return data;
+ }
+
+ public unsafe static byte[] ByteStringPieceFromNativeString(IntPtr handle)
{
if (handle == IntPtr.Zero)
{
@@ -66,7 +76,8 @@ public unsafe static byte[] ByteStringPiece(IntPtr handle)
current = *(str_data++);
bytes.Add(current);
}
- return bytes.Take(bytes.Count - 1).ToArray();
+ var data = bytes.ToArray();
+ return data;
}
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
diff --git a/src/TensorFlowNET.Core/APIs/c_api.customize.cs b/src/TensorFlowNET.Core/APIs/c_api.customize.cs
index d2aab9ac0..510e52eb7 100644
--- a/src/TensorFlowNET.Core/APIs/c_api.customize.cs
+++ b/src/TensorFlowNET.Core/APIs/c_api.customize.cs
@@ -10,7 +10,7 @@ public partial class c_api
[DllImport(TensorFlowLibName)]
public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
+ public static extern SafeBufferHandle TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
[DllImport(TensorFlowLibName)]
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
}
diff --git a/src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs b/src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs
new file mode 100644
index 000000000..2c20cfe9b
--- /dev/null
+++ b/src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs
@@ -0,0 +1,25 @@
+using Tensorflow;
+
+internal static class GraphOnlyOps
+{
+ ///
+ /// Graph-only version of tf.compat.v1.placeholder(), for internal use only.
+ ///
+ ///
+ ///
+ ///
+ ///
+ internal static Tensor graph_placeholder(TF_DataType dtype, Shape shape, string? name = null)
+ {
+ var dtype_value = new AttrValue() { Type = dtype.as_datatype_enum() };
+ var shape_value = new AttrValue() { Shape = shape.as_proto() };
+ var g = ops.get_default_graph();
+ Dictionary attrs = new();
+ attrs["dtype"] = dtype_value;
+ attrs["shape"] = shape_value;
+ var op = g.create_op("Placeholder", new Tensor[0], new TF_DataType[] { dtype },
+ new TF_DataType[0], attrs: attrs, name: name);
+ var result = op.outputs[0];
+ return result;
+ }
+}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs
index ba7d7068e..6f7fa9c5f 100644
--- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs
+++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs
@@ -544,12 +544,12 @@ private static object _get_defun_input(object arg, string name)
Tensor placeholder;
try
{
- placeholder = tf.placeholder(tensor.dtype, tensor.shape, name);
+ placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape, name);
}
- catch (ValueError)
+ catch (ValueError ex)
{
- // TODO(Rinne): Add warning here.
- placeholder = tf.placeholder(tensor.dtype, tensor.shape);
+ tf.Logger.Warning(ex.ToString());
+ placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape);
}
handle_data_util.copy_handle_data(tensor, placeholder);
if (name is not null)
@@ -575,12 +575,12 @@ private static object _get_defun_input(object arg, string name)
Tensor placeholder;
try
{
- placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name);
+ placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape, requested_name);
}
catch (ValueError)
{
// TODO(Rinne): Add warning here.
- placeholder = tf.placeholder(spec.dtype, spec.shape);
+ placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape);
}
if (name is not null)
{
diff --git a/src/TensorFlowNET.Core/Operations/list_ops.cs b/src/TensorFlowNET.Core/Operations/list_ops.cs
index c5e83ee41..3791a2c19 100644
--- a/src/TensorFlowNET.Core/Operations/list_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/list_ops.cs
@@ -31,7 +31,7 @@ private static Tensor _build_element_shape(Shape? shape)
}
else
{
- return ops.convert_to_tensor(shape);
+ return ops.convert_to_tensor(shape, dtype: dtypes.int32);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/while_v2.cs b/src/TensorFlowNET.Core/Operations/while_v2.cs
index 3f324f872..aae15b77d 100644
--- a/src/TensorFlowNET.Core/Operations/while_v2.cs
+++ b/src/TensorFlowNET.Core/Operations/while_v2.cs
@@ -38,9 +38,9 @@ public static Tensor[] while_loop(Func cond,
int len_orig_loop_vars = orig_loop_vars.Length;
loop_vars = _tensor_array_to_flow(loop_vars);
- loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x, TF_DataType.DtInvalid, null), loop_vars).ToTensors();
+ loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x), loop_vars).ToTensors();
- var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), _tensor_array_to_flow(loop_vars));
+ var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), loop_vars);
var flat_shape_invariants = Nest.Flatten(loop_vars_signature).Select(x => x.shape).ToArray();
@@ -379,10 +379,9 @@ private static string _build_cond_placeholders_name_prefix(FuncGraph cond_graph)
return cond_graph.unique_name(cond_graph.Name + "___redundant_placeholder");
}
- private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype,
- string name)
+ private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value)
{
- return ops.convert_to_tensor(value, dtype, name, false);
+ return ops.convert_to_tensor(value, as_ref: false);
}
private static Tensor _build_maximum_iterations_loop_var(int maximum_iterations = -1)
diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs
index fb9bccf31..7bd78a79f 100644
--- a/src/TensorFlowNET.Core/ops.cs
+++ b/src/TensorFlowNET.Core/ops.cs
@@ -576,7 +576,14 @@ public static bool inside_function()
public static HandleData get_resource_handle_data(Tensor graph_op)
{
var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
- return HandleData.Parser.ParseFrom(c_api.ByteStringPiece(handle_data));
+ try{
+ var handle_str = c_api.ByteStringPiece(handle_data.DangerousGetHandle() == IntPtr.Zero ? null : new Buffer(handle_data));
+ return HandleData.Parser.ParseFrom(handle_str);
+ }
+ catch(Exception){
+ var handle_str = c_api.ByteStringPieceFromNativeString(handle_data.DangerousGetHandle());
+ return HandleData.Parser.ParseFrom(handle_str);
+ }
}
public static void dismantle_graph(Graph graph)
diff --git a/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj b/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj
index 878077582..1ca387dbb 100644
--- a/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj
+++ b/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj
@@ -5,7 +5,7 @@
-
+