Skip to content

Support loading of SavedModel format #989

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Mar 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,22 @@ public static void add_checkpoint_values_check(TrackableObjectGraph object_graph
// object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i);
// }
}

/// <summary>
/// Traverse the object graph and list all accessible objects.
/// </summary>
/// <param name="object_graph_view"></param>
public static IList<Trackable> list_objects(ObjectGraphView graph_view)
{
return objects_ids_and_slot_variables_and_paths(graph_view).Item1;
}

internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list)
{
return full_list.TakeWhile(x =>
{
var saveables = x.gather_saveables_for_checkpoint();
return saveables is not null && saveables.Count > 0;
});
}
}
100 changes: 100 additions & 0 deletions src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
using Tensorflow.Util;

namespace Tensorflow.Checkpoint
{
sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle
{
public SafeCheckpointReaderHandle(): base()
{

}
public SafeCheckpointReaderHandle(IntPtr handle): base(handle)
{

}

protected override bool ReleaseHandle()
{
c_api.TF_DeleteCheckpointReader(handle);
SetHandle(IntPtr.Zero);
return true;
}
}
public class CheckpointReader
{
private SafeCheckpointReaderHandle _handle;
public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; }
public Dictionary<string, Shape> VariableToShapeMap { get; set; }

public CheckpointReader(string filename)
{
Status status = new Status();
_handle = c_api.TF_NewCheckpointReader(filename, status.Handle);
status.Check(true);
ReadAllShapeAndType();
}

public int HasTensor(string name)
{
return c_api.TF_CheckpointReaderHasTensor(_handle, name);
}

/// <summary>
/// Get the variable name.
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
public string GetVariable(int index)
{
return c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index));
}

public int Size()
{
return c_api.TF_CheckpointReaderSize(_handle);
}

public TF_DataType GetVariableDataType(string name)
{
return c_api.TF_CheckpointReaderGetVariableDataType(_handle, name);
}

public Shape GetVariableShape(string name)
{
int num_dims = GetVariableNumDims(name);
long[] dims = new long[num_dims];
Status status = new Status();
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle);
status.Check(true);
return new Shape(dims);
}

public int GetVariableNumDims(string name)
{
return c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name);
}

public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
{
Status status = new Status();
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle);
status.Check(true);
return new Tensor(tensor);
}

private void ReadAllShapeAndType()
{
VariableToDataTypeMap = new Dictionary<string, TF_DataType>();
VariableToShapeMap = new Dictionary<string, Shape>();
int size = Size();
for(int i = 0; i < size; i++)
{
var name = GetVariable(i);
var shape = GetVariableShape(name);
var dtype = GetVariableDataType(name);
VariableToDataTypeMap[name] = dtype;
VariableToShapeMap[name] = shape;
}
}
}
}
6 changes: 3 additions & 3 deletions src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@ public static (IList<MySaveableObject>, object?) generate_saveable_objects(
{
var name = factory_data.name;
var key = factory_data.checkpoint_key;
var maybe_saveable = factory_data.factory;
var maybe_saveable = saveable_object_util.create_saveable_object(name, key, factory_data.factory);

// TODO: oneflow python has a process with callable `saveable_factory`.
// TODO: tensorflow python has a process with callable `saveable_factory`.
List<MySaveableObject> saveables = new();
if (maybe_saveable.TryGet<MySaveableObject>(out var s))
{
Expand Down Expand Up @@ -217,7 +217,7 @@ public static (IList<MySaveableObject>, object?) generate_saveable_objects(

public record class CheckpointFactoryData
(
Maybe<BaseResourceVariable, MySaveableObject> factory,
Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory,
string name,
string checkpoint_key
);
27 changes: 27 additions & 0 deletions src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using System.Runtime.InteropServices;
using Tensorflow.Checkpoint;

namespace Tensorflow
{
public unsafe partial class c_api
{
[DllImport(TensorFlowLibName)]
internal static extern SafeCheckpointReaderHandle TF_NewCheckpointReader(string filename, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
internal static extern void TF_DeleteCheckpointReader(IntPtr reader);
[DllImport(TensorFlowLibName)]
internal static extern int TF_CheckpointReaderHasTensor(SafeCheckpointReaderHandle reader, string name);
[DllImport(TensorFlowLibName)]
internal static extern IntPtr TF_CheckpointReaderGetVariable(SafeCheckpointReaderHandle reader, int index);
[DllImport(TensorFlowLibName)]
internal static extern int TF_CheckpointReaderSize(SafeCheckpointReaderHandle reader);
[DllImport(TensorFlowLibName)]
internal static extern TF_DataType TF_CheckpointReaderGetVariableDataType(SafeCheckpointReaderHandle reader, string name);
[DllImport(TensorFlowLibName)]
internal static extern void TF_CheckpointReaderGetVariableShape(SafeCheckpointReaderHandle reader, string name, long[] dims, int num_dims, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
internal static extern int TF_CheckpointReaderGetVariableNumDims(SafeCheckpointReaderHandle reader, string name);
[DllImport(TensorFlowLibName)]
internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(SafeCheckpointReaderHandle reader, string name, SafeStatusHandle status);
}
}
Loading