Skip to content

Commit 875063a

Browse files
AsakusaRinneOceania2018
authored andcommitted
Add more explicit conversion for Tensors.
1 parent 1767c3c commit 875063a

File tree

2 files changed

+110
-9
lines changed

2 files changed

+110
-9
lines changed

src/TensorFlowNET.Core/Tensors/Tensors.cs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,93 @@ public void Insert(int index, Tensor tensor)
6565
IEnumerator IEnumerable.GetEnumerator()
6666
=> GetEnumerator();
6767

68+
public NDArray numpy()
69+
{
70+
EnsureSingleTensor(this, "nnumpy");
71+
return this[0].numpy();
72+
}
73+
74+
public T[] ToArray<T>() where T: unmanaged
75+
{
76+
EnsureSingleTensor(this, $"ToArray<{typeof(T)}>");
77+
return this[0].ToArray<T>();
78+
}
79+
80+
#region Explicit Conversions
81+
public unsafe static explicit operator bool(Tensors tensor)
82+
{
83+
EnsureSingleTensor(tensor, "explicit conversion to bool");
84+
return (bool)tensor[0];
85+
}
86+
87+
public unsafe static explicit operator sbyte(Tensors tensor)
88+
{
89+
EnsureSingleTensor(tensor, "explicit conversion to sbyte");
90+
return (sbyte)tensor[0];
91+
}
92+
93+
public unsafe static explicit operator byte(Tensors tensor)
94+
{
95+
EnsureSingleTensor(tensor, "explicit conversion to byte");
96+
return (byte)tensor[0];
97+
}
98+
99+
public unsafe static explicit operator ushort(Tensors tensor)
100+
{
101+
EnsureSingleTensor(tensor, "explicit conversion to ushort");
102+
return (ushort)tensor[0];
103+
}
104+
105+
public unsafe static explicit operator short(Tensors tensor)
106+
{
107+
EnsureSingleTensor(tensor, "explicit conversion to short");
108+
return (short)tensor[0];
109+
}
110+
111+
public unsafe static explicit operator int(Tensors tensor)
112+
{
113+
EnsureSingleTensor(tensor, "explicit conversion to int");
114+
return (int)tensor[0];
115+
}
116+
117+
public unsafe static explicit operator uint(Tensors tensor)
118+
{
119+
EnsureSingleTensor(tensor, "explicit conversion to uint");
120+
return (uint)tensor[0];
121+
}
122+
123+
public unsafe static explicit operator long(Tensors tensor)
124+
{
125+
EnsureSingleTensor(tensor, "explicit conversion to long");
126+
return (long)tensor[0];
127+
}
128+
129+
public unsafe static explicit operator ulong(Tensors tensor)
130+
{
131+
EnsureSingleTensor(tensor, "explicit conversion to ulong");
132+
return (ulong)tensor[0];
133+
}
134+
135+
public unsafe static explicit operator float(Tensors tensor)
136+
{
137+
EnsureSingleTensor(tensor, "explicit conversion to byte");
138+
return (byte)tensor[0];
139+
}
140+
141+
public unsafe static explicit operator double(Tensors tensor)
142+
{
143+
EnsureSingleTensor(tensor, "explicit conversion to double");
144+
return (double)tensor[0];
145+
}
146+
147+
public unsafe static explicit operator string(Tensors tensor)
148+
{
149+
EnsureSingleTensor(tensor, "explicit conversion to string");
150+
return (string)tensor[0];
151+
}
152+
#endregion
153+
154+
#region Implicit Conversions
68155
public static implicit operator Tensors(Tensor tensor)
69156
=> new Tensors(tensor);
70157

@@ -87,12 +174,26 @@ public static implicit operator Tensor(Tensors tensors)
87174
public static implicit operator Tensor[](Tensors tensors)
88175
=> tensors.items.ToArray();
89176

177+
#endregion
178+
90179
public void Deconstruct(out Tensor a, out Tensor b)
91180
{
92181
a = items[0];
93182
b = items[1];
94183
}
95184

185+
private static void EnsureSingleTensor(Tensors tensors, string methodnName)
186+
{
187+
if(tensors.Length == 0)
188+
{
189+
throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains no Tensor.");
190+
}
191+
else if(tensors.Length > 1)
192+
{
193+
throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor.");
194+
}
195+
}
196+
96197
public override string ToString()
97198
=> items.Count() == 1
98199
? items.First().ToString()

test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public void Range()
2020
Assert.AreEqual(iStep, step);
2121
iStep++;
2222

23-
Assert.AreEqual(value, (long)item.Item1[0]);
23+
Assert.AreEqual(value, (long)item.Item1);
2424
value++;
2525
}
2626
}
@@ -39,7 +39,7 @@ public void Prefetch()
3939
Assert.AreEqual(iStep, step);
4040
iStep++;
4141

42-
Assert.AreEqual(value, (long)item.Item1[0]);
42+
Assert.AreEqual(value, (long)item.Item1);
4343
value += 2;
4444
}
4545
}
@@ -54,7 +54,7 @@ public void FromTensorSlices()
5454
int n = 0;
5555
foreach (var (item_x, item_y) in dataset)
5656
{
57-
print($"x:{item_x[0].numpy()},y:{item_y[0].numpy()}");
57+
print($"x:{item_x.numpy()},y:{item_y.numpy()}");
5858
n += 1;
5959
}
6060
Assert.AreEqual(5, n);
@@ -69,7 +69,7 @@ public void FromTensor()
6969
int n = 0;
7070
foreach (var x in dataset)
7171
{
72-
Assert.IsTrue(X.SequenceEqual(x.Item1[0].ToArray<int>()));
72+
Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray<int>()));
7373
n += 1;
7474
}
7575
Assert.AreEqual(1, n);
@@ -85,15 +85,15 @@ public void Shard()
8585

8686
foreach (var item in dataset2)
8787
{
88-
Assert.AreEqual(value, (long)item.Item1[0]);
88+
Assert.AreEqual(value, (long)item.Item1);
8989
value += 3;
9090
}
9191

9292
value = 1;
9393
var dataset3 = dataset1.shard(num_shards: 3, index: 1);
9494
foreach (var item in dataset3)
9595
{
96-
Assert.AreEqual(value, (long)item.Item1[0]);
96+
Assert.AreEqual(value, (long)item.Item1);
9797
value += 3;
9898
}
9999
}
@@ -108,7 +108,7 @@ public void Skip()
108108

109109
foreach (var item in dataset)
110110
{
111-
Assert.AreEqual(value, (long)item.Item1[0]);
111+
Assert.AreEqual(value, (long)item.Item1);
112112
value++;
113113
}
114114
}
@@ -123,7 +123,7 @@ public void Map()
123123

124124
foreach (var item in dataset)
125125
{
126-
Assert.AreEqual(value + 10, (long)item.Item1[0]);
126+
Assert.AreEqual(value + 10, (long)item.Item1);
127127
value++;
128128
}
129129
}
@@ -138,7 +138,7 @@ public void Cache()
138138

139139
foreach (var item in dataset)
140140
{
141-
Assert.AreEqual(value, (long)item.Item1[0]);
141+
Assert.AreEqual(value, (long)item.Item1);
142142
value++;
143143
}
144144
}

0 commit comments

Comments
 (0)