Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit abd5cf7

Browse files
authored
Merge pull request #33 from saddam213/CoreApi
Core API Improvements
2 parents 8f53c9f + 9e9efde commit abd5cf7

19 files changed

+652
-529
lines changed

OnnxStack.Core/Extensions.cs renamed to OnnxStack.Core/Extensions/Extensions.cs

Lines changed: 0 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
using Microsoft.ML.OnnxRuntime;
2-
using Microsoft.ML.OnnxRuntime.Tensors;
32
using OnnxStack.Core.Config;
43
using System;
5-
using System.Buffers;
64
using System.Collections.Concurrent;
75
using System.Collections.Generic;
86
using System.Linq;
97
using System.Numerics;
10-
using System.Runtime.InteropServices;
118

129
namespace OnnxStack.Core
1310
{
@@ -238,135 +235,5 @@ public static long[] ToLong(this int[] array)
238235
{
239236
return Array.ConvertAll(array, Convert.ToInt64);
240237
}
241-
242-
243-
/// <summary>
244-
/// Creates and OrtValue form the DenseTensor and NodeMetaData provided
245-
/// </summary>
246-
/// <param name="tensor">The tensor.</param>
247-
/// <param name="nodeMetadata">The node metadata.</param>
248-
/// <returns></returns>
249-
public static OrtValue ToOrtValue(this DenseTensor<float> tensor, NodeMetadata nodeMetadata)
250-
{
251-
var dimensions = tensor.Dimensions.ToLong();
252-
return nodeMetadata.ElementDataType switch
253-
{
254-
TensorElementType.Float16 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToFloat16(), dimensions),
255-
TensorElementType.BFloat16 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToBFloat16(), dimensions),
256-
_ => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, dimensions)
257-
};
258-
}
259-
260-
261-
/// <summary>
262-
/// Creates and allocates output tensors buffer.
263-
/// </summary>
264-
/// <param name="nodeMetadata">The node metadata.</param>
265-
/// <param name="dimensions">The dimensions.</param>
266-
/// <returns></returns>
267-
public static OrtValue CreateOutputBuffer(this NodeMetadata nodeMetadata, ReadOnlySpan<int> dimensions)
268-
{
269-
return OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, nodeMetadata.ElementDataType, dimensions.ToLong());
270-
}
271-
272-
273-
/// <summary>
274-
/// Converts to DenseTensor<float>.
275-
/// </summary>
276-
/// <param name="ortValue">The ort value.</param>
277-
/// <returns></returns>
278-
public static DenseTensor<float> ToDenseTensor(this OrtValue ortValue)
279-
{
280-
var typeInfo = ortValue.GetTensorTypeAndShape();
281-
var dimensions = typeInfo.Shape.ToInt();
282-
return typeInfo.ElementDataType switch
283-
{
284-
TensorElementType.Float16 => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<Float16>().ToFloat(), dimensions),
285-
TensorElementType.BFloat16 => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<BFloat16>().ToFloat(), dimensions),
286-
_ => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<float>().ToArray(), dimensions)
287-
};
288-
}
289-
290-
291-
/// <summary>
292-
/// Converts to array.
293-
/// </summary>
294-
/// <param name="ortValue">The ort value.</param>
295-
/// <returns></returns>
296-
public static float[] ToArray(this OrtValue ortValue)
297-
{
298-
var typeInfo = ortValue.GetTensorTypeAndShape();
299-
var dimensions = typeInfo.Shape.ToInt();
300-
return typeInfo.ElementDataType switch
301-
{
302-
TensorElementType.Float16 => ortValue.GetTensorDataAsSpan<Float16>().ToFloat().ToArray(),
303-
TensorElementType.BFloat16 => ortValue.GetTensorDataAsSpan<BFloat16>().ToFloat().ToArray(),
304-
_ => ortValue.GetTensorDataAsSpan<float>().ToArray()
305-
};
306-
}
307-
308-
309-
/// <summary>
310-
/// Converts to float16.
311-
/// </summary>
312-
/// <param name="inputMemory">The input memory.</param>
313-
/// <returns></returns>
314-
internal static Memory<Float16> ToFloat16(this Memory<float> inputMemory)
315-
{
316-
var elementCount = inputMemory.Length;
317-
var floatArray = new Float16[inputMemory.Length];
318-
for (int i = 0; i < elementCount; i++)
319-
floatArray[i] = (Float16)inputMemory.Span[i];
320-
321-
return floatArray.AsMemory();
322-
}
323-
324-
325-
/// <summary>
326-
/// Converts to BFloat16.
327-
/// </summary>
328-
/// <param name="inputMemory">The input memory.</param>
329-
/// <returns></returns>
330-
internal static Memory<BFloat16> ToBFloat16(this Memory<float> inputMemory)
331-
{
332-
var elementCount = inputMemory.Length;
333-
var floatArray = new BFloat16[inputMemory.Length];
334-
for (int i = 0; i < elementCount; i++)
335-
floatArray[i] = (BFloat16)inputMemory.Span[i];
336-
337-
return floatArray.AsMemory();
338-
}
339-
340-
341-
/// <summary>
342-
/// Converts to float.
343-
/// </summary>
344-
/// <param name="inputMemory">The input memory.</param>
345-
/// <returns></returns>
346-
internal static Memory<float> ToFloat(this ReadOnlySpan<Float16> inputMemory)
347-
{
348-
var elementCount = inputMemory.Length;
349-
var floatArray = new float[elementCount];
350-
for (int i = 0; i < elementCount; i++)
351-
floatArray[i] = (float)inputMemory[i];
352-
353-
return floatArray.AsMemory();
354-
}
355-
356-
357-
/// <summary>
358-
/// Converts to float.
359-
/// </summary>
360-
/// <param name="inputMemory">The input memory.</param>
361-
/// <returns></returns>
362-
internal static Memory<float> ToFloat(this ReadOnlySpan<BFloat16> inputMemory)
363-
{
364-
var elementCount = inputMemory.Length;
365-
var floatArray = new float[elementCount];
366-
for (int i = 0; i < elementCount; i++)
367-
floatArray[i] = (float)inputMemory[i];
368-
369-
return floatArray.AsMemory();
370-
}
371238
}
372239
}
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
using Microsoft.ML.OnnxRuntime;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core.Model;
4+
using System;
5+
6+
namespace OnnxStack.Core
7+
{
8+
public static class OrtValueExtensions
9+
{
10+
/// <summary>
11+
/// Creates and OrtValue form the DenseTensor and NodeMetaData provided
12+
/// TODO: Optimization
13+
/// </summary>
14+
/// <param name="tensor">The tensor.</param>
15+
/// <param name="metadata">The metadata.</param>
16+
/// <returns></returns>
17+
public static OrtValue ToOrtValue(this DenseTensor<float> tensor, OnnxNamedMetadata metadata)
18+
{
19+
var dimensions = tensor.Dimensions.ToLong();
20+
return metadata.Value.ElementDataType switch
21+
{
22+
TensorElementType.Int64 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToLong(), dimensions),
23+
TensorElementType.Float16 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToFloat16(), dimensions),
24+
TensorElementType.BFloat16 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToBFloat16(), dimensions),
25+
_ => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, dimensions)
26+
};
27+
}
28+
29+
30+
/// <summary>
31+
/// Converts DenseTensor<string> to OrtValue.
32+
/// </summary>
33+
/// <param name="tensor">The tensor.</param>
34+
/// <returns></returns>
35+
public static OrtValue ToOrtValue(this DenseTensor<string> tensor, OnnxNamedMetadata metadata)
36+
{
37+
return OrtValue.CreateFromStringTensor(tensor);
38+
}
39+
40+
41+
/// <summary>
42+
/// Converts DenseTensor<int> to OrtValue.
43+
/// </summary>
44+
/// <param name="tensor">The tensor.</param>
45+
/// <returns></returns>
46+
public static OrtValue ToOrtValue(this DenseTensor<int> tensor, OnnxNamedMetadata metadata)
47+
{
48+
return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, tensor.Dimensions.ToLong());
49+
}
50+
51+
52+
/// <summary>
53+
/// Creates and allocates the output tensors buffer.
54+
/// </summary>
55+
/// <param name="metadata">The metadata.</param>
56+
/// <param name="dimensions">The dimensions.</param>
57+
/// <returns></returns>
58+
public static OrtValue CreateOutputBuffer(this OnnxNamedMetadata metadata, ReadOnlySpan<int> dimensions)
59+
{
60+
return OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, metadata.Value.ElementDataType, dimensions.ToLong());
61+
}
62+
63+
64+
/// <summary>
65+
/// Converts to DenseTensor<float>.
66+
/// TODO: Optimization
67+
/// </summary>
68+
/// <param name="ortValue">The ort value.</param>
69+
/// <returns></returns>
70+
public static DenseTensor<float> ToDenseTensor(this OrtValue ortValue)
71+
{
72+
var typeInfo = ortValue.GetTensorTypeAndShape();
73+
var dimensions = typeInfo.Shape.ToInt();
74+
return typeInfo.ElementDataType switch
75+
{
76+
TensorElementType.Float16 => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<Float16>().ToFloat(), dimensions),
77+
TensorElementType.BFloat16 => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<BFloat16>().ToFloat(), dimensions),
78+
_ => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<float>().ToArray(), dimensions)
79+
};
80+
}
81+
82+
83+
/// <summary>
84+
/// Converts to array.
85+
/// TODO: Optimization
86+
/// </summary>
87+
/// <param name="ortValue">The ort value.</param>
88+
/// <returns></returns>
89+
public static float[] ToArray(this OrtValue ortValue)
90+
{
91+
var typeInfo = ortValue.GetTensorTypeAndShape();
92+
return typeInfo.ElementDataType switch
93+
{
94+
TensorElementType.Float16 => ortValue.GetTensorDataAsSpan<Float16>().ToFloat().ToArray(),
95+
TensorElementType.BFloat16 => ortValue.GetTensorDataAsSpan<BFloat16>().ToFloat().ToArray(),
96+
_ => ortValue.GetTensorDataAsSpan<float>().ToArray()
97+
};
98+
}
99+
100+
101+
/// <summary>
102+
/// Converts to float16.
103+
/// TODO: Optimization
104+
/// </summary>
105+
/// <param name="inputMemory">The input memory.</param>
106+
/// <returns></returns>
107+
private static Memory<Float16> ToFloat16(this Memory<float> inputMemory)
108+
{
109+
var elementCount = inputMemory.Length;
110+
var floatArray = new Float16[inputMemory.Length];
111+
for (int i = 0; i < elementCount; i++)
112+
floatArray[i] = (Float16)inputMemory.Span[i];
113+
114+
return floatArray.AsMemory();
115+
}
116+
117+
118+
/// <summary>
119+
/// Converts to BFloat16.
120+
/// TODO: Optimization
121+
/// </summary>
122+
/// <param name="inputMemory">The input memory.</param>
123+
/// <returns></returns>
124+
private static Memory<BFloat16> ToBFloat16(this Memory<float> inputMemory)
125+
{
126+
var elementCount = inputMemory.Length;
127+
var floatArray = new BFloat16[inputMemory.Length];
128+
for (int i = 0; i < elementCount; i++)
129+
floatArray[i] = (BFloat16)inputMemory.Span[i];
130+
131+
return floatArray.AsMemory();
132+
}
133+
134+
135+
/// <summary>
136+
/// Converts to float.
137+
/// TODO: Optimization
138+
/// </summary>
139+
/// <param name="inputMemory">The input memory.</param>
140+
/// <returns></returns>
141+
private static Memory<float> ToFloat(this ReadOnlySpan<Float16> inputMemory)
142+
{
143+
var elementCount = inputMemory.Length;
144+
var floatArray = new float[elementCount];
145+
for (int i = 0; i < elementCount; i++)
146+
floatArray[i] = (float)inputMemory[i];
147+
148+
return floatArray.AsMemory();
149+
}
150+
151+
152+
/// <summary>
153+
/// Converts to float.
154+
/// TODO: Optimization
155+
/// </summary>
156+
/// <param name="inputMemory">The input memory.</param>
157+
/// <returns></returns>
158+
private static Memory<float> ToFloat(this ReadOnlySpan<BFloat16> inputMemory)
159+
{
160+
var elementCount = inputMemory.Length;
161+
var floatArray = new float[elementCount];
162+
for (int i = 0; i < elementCount; i++)
163+
floatArray[i] = (float)inputMemory[i];
164+
165+
return floatArray.AsMemory();
166+
}
167+
168+
169+
/// <summary>
170+
/// Converts to long.
171+
/// </summary>
172+
/// <param name="inputMemory">The input memory.</param>
173+
/// <returns></returns>
174+
private static Memory<long> ToLong(this Memory<float> inputMemory)
175+
{
176+
return Array.ConvertAll(inputMemory.ToArray(), Convert.ToInt64).AsMemory();
177+
}
178+
}
179+
}

0 commit comments

Comments
 (0)