Skip to content

Commit 179c3f0

Browse files
authored
Merge pull request #1169 from lingbai-kong/ndarrayload
add: loading pickled npy file for imdb dataset loader
2 parents 70d681c + f57a6fe commit 179c3f0

File tree

14 files changed

+546
-63
lines changed

14 files changed

+546
-63
lines changed

src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
using System.Linq;
55
using System.Text;
66
using Tensorflow.Util;
7+
using Razorvine.Pickle;
8+
using Tensorflow.NumPy.Pickle;
79
using static Tensorflow.Binding;
810

911
namespace Tensorflow.NumPy
@@ -97,6 +99,13 @@ Array ReadValueMatrix(BinaryReader reader, Array matrix, int bytes, Type type, i
9799
return matrix;
98100
}
99101

102+
/// <summary>
/// Reads an object-typed array by unpickling the reader's underlying stream.
/// </summary>
/// <param name="reader">Reader whose <c>BaseStream</c> is handed to the unpickler.</param>
/// <param name="matrix">Not read by this method; kept so the signature mirrors the other matrix readers.</param>
/// <param name="shape">Not read by this method.</param>
/// <returns>
/// The array exposed by the unpickled <see cref="MultiArrayPickleWarpper"/> through its
/// implicit conversion to <see cref="Array"/>.
/// </returns>
/// <exception cref="InvalidCastException">
/// The stream did not unpickle to a <see cref="MultiArrayPickleWarpper"/>.
/// </exception>
Array ReadObjectMatrix(BinaryReader reader, Array matrix, int[] shape)
{
    // NOTE(review): this unpickles whatever the stream contains. Presumably only the
    // constructors registered with Unpickler can be instantiated - confirm before
    // feeding files from untrusted sources.
    using (var unpickler = new Unpickler())
    {
        object unpickled = unpickler.load(reader.BaseStream);
        if (unpickled is MultiArrayPickleWarpper wrapper)
            return wrapper;

        // Same exception type a failed hard cast would raise, but with context.
        string actual = unpickled == null ? "null" : unpickled.GetType().FullName;
        throw new InvalidCastException(
            $"Expected the pickled payload to be a numpy ndarray ({nameof(MultiArrayPickleWarpper)}), but got {actual}.");
    }
}
108+
100109
public (NDArray, NDArray) meshgrid<T>(T[] array, bool copy = true, bool sparse = false)
101110
{
102111
var tensors = array_ops.meshgrid(array, copy: copy, sparse: sparse);

src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,14 @@ public Array LoadMatrix(Stream stream)
2727
Array matrix = Array.CreateInstance(type, shape);
2828

2929
//if (type == typeof(String))
30-
//return ReadStringMatrix(reader, matrix, bytes, type, shape);
31-
return ReadValueMatrix(reader, matrix, bytes, type, shape);
30+
//return ReadStringMatrix(reader, matrix, bytes, type, shape);
31+
32+
if (type == typeof(Object))
33+
return ReadObjectMatrix(reader, matrix, shape);
34+
else
35+
{
36+
return ReadValueMatrix(reader, matrix, bytes, type, shape);
37+
}
3238
}
3339
}
3440

@@ -37,7 +43,7 @@ public T Load<T>(Stream stream)
3743
ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable
3844
{
3945
// if (typeof(T).IsArray && (typeof(T).GetElementType().IsArray || typeof(T).GetElementType() == typeof(string)))
40-
// return LoadJagged(stream) as T;
46+
// return LoadJagged(stream) as T;
4147
return LoadMatrix(stream) as T;
4248
}
4349

@@ -93,7 +99,7 @@ bool ParseReader(BinaryReader reader, out int bytes, out Type t, out int[] shape
9399
Type GetType(string dtype, out int bytes, out bool? isLittleEndian)
94100
{
95101
isLittleEndian = IsLittleEndian(dtype);
96-
bytes = Int32.Parse(dtype.Substring(2));
102+
bytes = dtype.Length > 2 ? Int32.Parse(dtype.Substring(2)) : 0;
97103

98104
string typeCode = dtype.Substring(1);
99105

@@ -121,6 +127,8 @@ Type GetType(string dtype, out int bytes, out bool? isLittleEndian)
121127
return typeof(Double);
122128
if (typeCode.StartsWith("S"))
123129
return typeof(String);
130+
if (typeCode.StartsWith("O"))
131+
return typeof(Object);
124132

125133
throw new NotSupportedException();
126134
}

src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ public class RandomizedImpl
1414
public NDArray permutation(NDArray x) => new NDArray(random_ops.random_shuffle(x));
1515

1616
[AutoNumPy]
17-
public void shuffle(NDArray x)
17+
public void shuffle(NDArray x, int? seed = null)
1818
{
19-
var y = random_ops.random_shuffle(x);
19+
var y = random_ops.random_shuffle(x, seed);
2020
Marshal.Copy(y.BufferToArray(), 0, x.TensorDataPointer, (int)x.bytesize);
2121
}
2222

src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ public class NDArrayConverter
1010
public unsafe static T Scalar<T>(NDArray nd) where T : unmanaged
1111
=> nd.dtype switch
1212
{
13+
TF_DataType.TF_BOOL => Scalar<T>(*(bool*)nd.data),
1314
TF_DataType.TF_UINT8 => Scalar<T>(*(byte*)nd.data),
1415
TF_DataType.TF_FLOAT => Scalar<T>(*(float*)nd.data),
1516
TF_DataType.TF_INT32 => Scalar<T>(*(int*)nd.data),
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.NumPy.Pickle
6+
{
7+
/// <summary>
/// Small holder for a <see cref="TF_DataType"/> that can be converted back to the
/// wrapped value implicitly.
/// </summary>
public class DTypePickleWarpper
{
    // The wrapped data type; only ever assigned from the constructor.
    TF_DataType dtype { get; set; }

    /// <summary>Wraps the given data type.</summary>
    /// <param name="dtype">The data type to carry.</param>
    public DTypePickleWarpper(TF_DataType dtype) => this.dtype = dtype;

    /// <summary>
    /// Accepts state arguments and deliberately ignores them.
    /// </summary>
    /// <param name="args">State arguments; not read.</param>
    public void __setstate__(object[] args)
    {
    }

    /// <summary>Unwraps the carried <see cref="TF_DataType"/>.</summary>
    public static implicit operator TF_DataType(DTypePickleWarpper dTypeWarpper)
        => dTypeWarpper.dtype;
}
20+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Diagnostics.CodeAnalysis;
4+
using System.Text;
5+
using Razorvine.Pickle;
6+
7+
namespace Tensorflow.NumPy.Pickle
8+
{
9+
/// <summary>
/// Pickle object constructor that maps a dtype type code (for example <c>"i4"</c>)
/// to a <see cref="DTypePickleWarpper"/> holding the matching <see cref="TF_DataType"/>.
/// </summary>
[SuppressMessage("ReSharper", "InconsistentNaming")]
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
[SuppressMessage("ReSharper", "MemberCanBeMadeStatic.Global")]
class DtypeConstructor : IObjectConstructor
{
    /// <summary>
    /// Builds the dtype wrapper from the constructor arguments found in the pickle stream.
    /// </summary>
    /// <param name="args">Constructor arguments; only the first one, the type code string, is read.</param>
    /// <returns>A <see cref="DTypePickleWarpper"/> for the mapped data type.</returns>
    /// <exception cref="ArgumentException">
    /// <paramref name="args"/> is null or empty, or its first element is not a string.
    /// </exception>
    /// <exception cref="NotSupportedException">The type code has no mapping.</exception>
    public object construct(object[] args)
    {
        if (args == null || args.Length == 0)
            throw new ArgumentException("numpy dtype reconstruction expects the type code as its first argument.", nameof(args));
        if (!(args[0] is string typeCode))
            throw new ArgumentException("The first argument of numpy dtype reconstruction must be a type code string.", nameof(args));

        return new DTypePickleWarpper(ToDataType(typeCode));
    }

    // Maps a type code to the corresponding np data type constant.
    private static TF_DataType ToDataType(string typeCode)
    {
        switch (typeCode)
        {
            case "b1": return np.@bool;
            case "i1": return np.@byte;
            case "i2": return np.int16;
            case "i4": return np.int32;
            case "i8": return np.int64;
            case "u1": return np.ubyte;
            case "u2": return np.uint16;
            case "u4": return np.uint32;
            case "u8": return np.uint64;
            case "f4": return np.float32;
            case "f8": return np.float64;
        }

        // String and object codes are matched by prefix only; ordinal comparison
        // keeps the match independent of the current culture.
        if (typeCode.StartsWith("S", StringComparison.Ordinal))
            return np.@string;
        if (typeCode.StartsWith("O", StringComparison.Ordinal))
            return np.@object;

        throw new NotSupportedException($"Unsupported numpy dtype type code: {typeCode}");
    }
}
52+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Diagnostics.CodeAnalysis;
4+
using System.Text;
5+
using Razorvine.Pickle;
6+
using Razorvine.Pickle.Objects;
7+
8+
namespace Tensorflow.NumPy.Pickle
9+
{
10+
/// <summary>
/// Pickle object constructor for the ndarray reconstruction call. It validates the
/// three constructor arguments and returns a <see cref="MultiArrayPickleWarpper"/>
/// created from the given shape and data type identifier.
/// </summary>
[SuppressMessage("ReSharper", "InconsistentNaming")]
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
[SuppressMessage("ReSharper", "MemberCanBeMadeStatic.Global")]
public class MultiArrayConstructor : IObjectConstructor
{
    /// <summary>
    /// Creates the wrapper from the reconstruction arguments.
    /// </summary>
    /// <param name="args">
    /// Exactly three values: the class descriptor (must be <c>numpy.ndarray</c>),
    /// the dimensions as boxed ints, and a one-letter data type identifier given
    /// either as a string or as UTF-8 bytes.
    /// </param>
    /// <returns>A new <see cref="MultiArrayPickleWarpper"/> with the parsed shape and data type.</returns>
    /// <exception cref="InvalidArgumentError">
    /// The argument count is not three, or the identifier is neither a string nor a byte array.
    /// </exception>
    /// <exception cref="RuntimeError">The first argument does not describe <c>numpy.ndarray</c>.</exception>
    /// <exception cref="NotImplementedException">The identifier is not one of <c>u</c>, <c>c</c>, <c>f</c>, <c>b</c>.</exception>
    public object construct(object[] args)
    {
        if (args.Length != 3)
            throw new InvalidArgumentError($"Invalid number of arguments in MultiArrayConstructor._reconstruct. Expected three arguments. Given {args.Length} arguments.");

        // A null or foreign descriptor is reported the same way as a wrong class name.
        var types = args[0] as ClassDictConstructor;
        if (types == null || types.module != "numpy" || types.name != "ndarray")
            throw new RuntimeError("_reconstruct: First argument must be a sub-type of ndarray");

        var arg1 = (object[])args[1];
        var dims = new int[arg1.Length];
        for (var i = 0; i < arg1.Length; i++)
        {
            dims[i] = (int)arg1[i];
        }
        var shape = new Shape(dims);

        string identifier;
        switch (args[2])
        {
            case string text:
                identifier = text;
                break;
            case byte[] raw:
                identifier = Encoding.UTF8.GetString(raw);
                break;
            default:
                throw new InvalidArgumentError("_reconstruct: Third argument must be a data type identifier given as a string or as bytes.");
        }

        TF_DataType dtype;
        switch (identifier)
        {
            case "u": dtype = np.uint32; break;
            case "c": dtype = np.complex_; break;
            case "f": dtype = np.float32; break;
            case "b": dtype = np.@bool; break;
            // Report the decoded identifier; formatting args[2] directly would print
            // "System.Byte[]" when the identifier arrived as bytes.
            default: throw new NotImplementedException($"Unsupported data type: {identifier}");
        }
        return new MultiArrayPickleWarpper(shape, dtype);
    }
}
53+
}
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
using Newtonsoft.Json.Linq;
2+
using Serilog.Debugging;
3+
using System;
4+
using System.Collections;
5+
using System.Collections.Generic;
6+
using System.Text;
7+
8+
namespace Tensorflow.NumPy.Pickle
9+
{
10+
/// <summary>
/// Receives the state of a pickled array through <see cref="__setstate__"/> and exposes
/// the rebuilt data both as a CLR <see cref="Array"/> and as an <see cref="NDArray"/>.
/// Only one- and two-dimensional lists of <see cref="int"/> are supported.
/// </summary>
public class MultiArrayPickleWarpper
{
    /// <summary>Shape passed to the constructor.</summary>
    public Shape reconstructedShape { get; set; }
    /// <summary>Data type passed to the constructor.</summary>
    public TF_DataType reconstructedDType { get; set; }
    /// <summary>Array rebuilt by <see cref="__setstate__"/>; not assigned before that call.</summary>
    public NDArray reconstructedNDArray { get; set; }
    /// <summary>CLR array rebuilt by <see cref="__setstate__"/>; not assigned before that call.</summary>
    public Array reconstructedMultiArray { get; set; }

    /// <summary>Creates a wrapper that remembers the given shape and data type.</summary>
    public MultiArrayPickleWarpper(Shape shape, TF_DataType dtype)
    {
        reconstructedShape = shape;
        reconstructedDType = dtype;
    }

    /// <summary>
    /// Applies the pickled state: version, shape, dtype, the F-continuous flag and the data.
    /// </summary>
    /// <param name="args">Exactly five state values in the order listed above.</param>
    /// <exception cref="InvalidArgumentError">
    /// The argument count is not five, or the F-continuous flag is set.
    /// </exception>
    /// <exception cref="ValueError">The version is outside 0..4.</exception>
    /// <exception cref="NotImplementedException">The data is not an <see cref="ArrayList"/> of a supported layout.</exception>
    public void __setstate__(object[] args)
    {
        if (args.Length != 5)
            throw new InvalidArgumentError($"Invalid number of arguments in NDArray.__setstate__. Expected five arguments. Given {args.Length} arguments.");

        var version = (int)args[0]; // version
        /*
         * If we ever need another pickle format, increment the version
         * number. But we should still be able to handle the old versions.
         */
        // Checked first so an unknown format is rejected before the rest of the state is interpreted.
        if (version < 0 || version > 4)
            throw new ValueError($"can't handle version {version} of numpy.ndarray pickle");

        // Shape and dtype are parsed so malformed state fails here; the parsed
        // values are not used any further by this method.
        var arg1 = (object[])args[1];
        var dims = new int[arg1.Length];
        for (var i = 0; i < arg1.Length; i++)
        {
            dims[i] = (int)arg1[i];
        }
        var _ShapeLike = new Shape(dims); // shape

        TF_DataType _DType_co = (DTypePickleWarpper)args[2]; // DType

        var F_continuous = (bool)args[3]; // F-continuous
        if (F_continuous)
            throw new InvalidArgumentError("Fortran Continuous memory layout is not supported. Please use C-continuous layout or check the data format.");

        var data = args[4]; // Data

        // TODO: Implement the missing details and checks from the official Numpy C code here.
        // https://github.com/numpy/numpy/blob/2f0bd6e86a77e4401d0384d9a75edf9470c5deb6/numpy/core/src/multiarray/descriptor.c#L2761

        if (data is ArrayList list)
        {
            Reconstruct(list);
        }
        else
            throw new NotImplementedException("can't handle array data that is not an ArrayList.");
    }

    // Determines nesting depth and element type from the first element of each
    // level, then dispatches to the matching builder.
    private void Reconstruct(ArrayList arrayList)
    {
        int ndim = 1;
        var innermost = arrayList;
        while (innermost.Count > 0 && innermost[0] is ArrayList nested)
        {
            innermost = nested;
            ndim += 1;
        }

        // Without a first element there is nothing to infer the element type from.
        if (innermost.Count == 0)
            throw new NotImplementedException("can't handle an empty ArrayList: the element type cannot be determined.");
        if (!(innermost[0] is int))
            throw new NotImplementedException("can't handle ArrayList elements that are not System.Int32.");

        switch (ndim)
        {
            case 1:
                ReconstructVector(arrayList);
                break;
            case 2:
                ReconstructMatrix(arrayList);
                break;
            default:
                throw new NotImplementedException("can't handle ArrayList with more than two dimensions.");
        }
    }

    // Builds the one-dimensional result; ToArray throws if any element is not an int.
    private void ReconstructVector(ArrayList arrayList)
    {
        int[] list = (int[])arrayList.ToArray(typeof(int));
        Shape shape = new Shape(new int[] { arrayList.Count });
        reconstructedMultiArray = list;
        reconstructedNDArray = new NDArray(list, shape);
    }

    // Builds the two-dimensional result. The column count is the longest row;
    // shorter rows keep the default value 0 in their remaining cells.
    private void ReconstructMatrix(ArrayList arrayList)
    {
        int secondDim = 0;
        foreach (var item in arrayList)
        {
            if (item == null)
                throw new NoNullAllowedException("the sub-array of ArrayList cannot be null.");
            var row = (ArrayList)item;
            secondDim = row.Count > secondDim ? row.Count : secondDim;
        }

        int[,] list = new int[arrayList.Count, secondDim];
        for (int i = 0; i < arrayList.Count; i++)
        {
            var row = (ArrayList)arrayList[i];
            for (int j = 0; j < row.Count; j++)
            {
                var element = row[j];
                if (element == null)
                    throw new NoNullAllowedException("the element of ArrayList cannot be null.");
                list[i, j] = (int)element;
            }
        }

        Shape shape = new Shape(new int[] { arrayList.Count, secondDim });
        reconstructedMultiArray = list;
        reconstructedNDArray = new NDArray(list, shape);
    }

    /// <summary>Exposes the rebuilt CLR array.</summary>
    public static implicit operator Array(MultiArrayPickleWarpper arrayWarpper)
    {
        return arrayWarpper.reconstructedMultiArray;
    }

    /// <summary>Exposes the rebuilt <see cref="NDArray"/>.</summary>
    public static implicit operator NDArray(MultiArrayPickleWarpper arrayWarpper)
    {
        return arrayWarpper.reconstructedNDArray;
    }
}
119+
}

src/TensorFlowNET.Core/Numpy/Numpy.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ public partial class np
4343
public static readonly TF_DataType @decimal = TF_DataType.TF_DOUBLE;
4444
public static readonly TF_DataType complex_ = TF_DataType.TF_COMPLEX;
4545
public static readonly TF_DataType complex64 = TF_DataType.TF_COMPLEX64;
46-
public static readonly TF_DataType complex128 = TF_DataType.TF_COMPLEX128;
46+
public static readonly TF_DataType complex128 = TF_DataType.TF_COMPLEX128;
47+
public static readonly TF_DataType @string = TF_DataType.TF_STRING;
48+
public static readonly TF_DataType @object = TF_DataType.TF_VARIANT;
4749
#endregion
4850

4951
public static double nan => double.NaN;

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ https://tensorflownet.readthedocs.io</Description>
176176
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
177177
<PackageReference Include="OneOf" Version="3.0.255" />
178178
<PackageReference Include="Protobuf.Text" Version="0.7.1" />
179+
<PackageReference Include="Razorvine.Pickle" Version="1.4.0" />
179180
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
180181
</ItemGroup>
181182

src/TensorFlowNET.Core/tensorflow.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using Razorvine.Pickle;
1718
using Serilog;
1819
using Serilog.Core;
1920
using System.Reflection;
@@ -22,6 +23,7 @@ limitations under the License.
2223
using Tensorflow.Eager;
2324
using Tensorflow.Gradients;
2425
using Tensorflow.Keras;
26+
using Tensorflow.NumPy.Pickle;
2527

2628
namespace Tensorflow
2729
{
@@ -98,6 +100,10 @@ public tensorflow()
98100
"please visit https://github.com/SciSharp/TensorFlow.NET. If it still not work after installing the backend, please submit an " +
99101
"issue to https://github.com/SciSharp/TensorFlow.NET/issues");
100102
}
103+
104+
// register numpy reconstructor for pickle
105+
Unpickler.registerConstructor("numpy.core.multiarray", "_reconstruct", new MultiArrayConstructor());
106+
Unpickler.registerConstructor("numpy", "dtype", new DtypeConstructor());
101107
}
102108

103109
public string VERSION => c_api.StringPiece(c_api.TF_Version());

0 commit comments

Comments
 (0)