Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Core API Improvements #33

Merged
merged 7 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core.Config;
using System;
using System.Buffers;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Numerics;
using System.Runtime.InteropServices;

namespace OnnxStack.Core
{
Expand Down Expand Up @@ -238,135 +235,5 @@ public static long[] ToLong(this int[] array)
{
return Array.ConvertAll(array, Convert.ToInt64);
}


/// <summary>
/// Creates and OrtValue form the DenseTensor and NodeMetaData provided
/// </summary>
/// <param name="tensor">The tensor.</param>
/// <param name="nodeMetadata">The node metadata.</param>
/// <returns></returns>
public static OrtValue ToOrtValue(this DenseTensor<float> tensor, NodeMetadata nodeMetadata)
{
var dimensions = tensor.Dimensions.ToLong();
return nodeMetadata.ElementDataType switch
{
TensorElementType.Float16 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToFloat16(), dimensions),
TensorElementType.BFloat16 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToBFloat16(), dimensions),
_ => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, dimensions)
};
}


/// <summary>
/// Creates and allocates output tensors buffer.
/// </summary>
/// <param name="nodeMetadata">The node metadata.</param>
/// <param name="dimensions">The dimensions.</param>
/// <returns></returns>
public static OrtValue CreateOutputBuffer(this NodeMetadata nodeMetadata, ReadOnlySpan<int> dimensions)
{
return OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, nodeMetadata.ElementDataType, dimensions.ToLong());
}


/// <summary>
/// Converts to DenseTensor<float>.
/// </summary>
/// <param name="ortValue">The ort value.</param>
/// <returns></returns>
public static DenseTensor<float> ToDenseTensor(this OrtValue ortValue)
{
var typeInfo = ortValue.GetTensorTypeAndShape();
var dimensions = typeInfo.Shape.ToInt();
return typeInfo.ElementDataType switch
{
TensorElementType.Float16 => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<Float16>().ToFloat(), dimensions),
TensorElementType.BFloat16 => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<BFloat16>().ToFloat(), dimensions),
_ => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<float>().ToArray(), dimensions)
};
}


/// <summary>
/// Converts to array.
/// </summary>
/// <param name="ortValue">The ort value.</param>
/// <returns></returns>
public static float[] ToArray(this OrtValue ortValue)
{
var typeInfo = ortValue.GetTensorTypeAndShape();
var dimensions = typeInfo.Shape.ToInt();
return typeInfo.ElementDataType switch
{
TensorElementType.Float16 => ortValue.GetTensorDataAsSpan<Float16>().ToFloat().ToArray(),
TensorElementType.BFloat16 => ortValue.GetTensorDataAsSpan<BFloat16>().ToFloat().ToArray(),
_ => ortValue.GetTensorDataAsSpan<float>().ToArray()
};
}


/// <summary>
/// Converts to float16.
/// </summary>
/// <param name="inputMemory">The input memory.</param>
/// <returns></returns>
internal static Memory<Float16> ToFloat16(this Memory<float> inputMemory)
{
var elementCount = inputMemory.Length;
var floatArray = new Float16[inputMemory.Length];
for (int i = 0; i < elementCount; i++)
floatArray[i] = (Float16)inputMemory.Span[i];

return floatArray.AsMemory();
}


/// <summary>
/// Converts to BFloat16.
/// </summary>
/// <param name="inputMemory">The input memory.</param>
/// <returns></returns>
internal static Memory<BFloat16> ToBFloat16(this Memory<float> inputMemory)
{
var elementCount = inputMemory.Length;
var floatArray = new BFloat16[inputMemory.Length];
for (int i = 0; i < elementCount; i++)
floatArray[i] = (BFloat16)inputMemory.Span[i];

return floatArray.AsMemory();
}


/// <summary>
/// Converts to float.
/// </summary>
/// <param name="inputMemory">The input memory.</param>
/// <returns></returns>
internal static Memory<float> ToFloat(this ReadOnlySpan<Float16> inputMemory)
{
var elementCount = inputMemory.Length;
var floatArray = new float[elementCount];
for (int i = 0; i < elementCount; i++)
floatArray[i] = (float)inputMemory[i];

return floatArray.AsMemory();
}


/// <summary>
/// Converts to float.
/// </summary>
/// <param name="inputMemory">The input memory.</param>
/// <returns></returns>
internal static Memory<float> ToFloat(this ReadOnlySpan<BFloat16> inputMemory)
{
var elementCount = inputMemory.Length;
var floatArray = new float[elementCount];
for (int i = 0; i < elementCount; i++)
floatArray[i] = (float)inputMemory[i];

return floatArray.AsMemory();
}
}
}
179 changes: 179 additions & 0 deletions OnnxStack.Core/Extensions/OrtValueExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core.Model;
using System;

namespace OnnxStack.Core
{
public static class OrtValueExtensions
{
/// <summary>
/// Creates and OrtValue form the DenseTensor and NodeMetaData provided
/// TODO: Optimization
/// </summary>
/// <param name="tensor">The tensor.</param>
/// <param name="metadata">The metadata.</param>
/// <returns></returns>
public static OrtValue ToOrtValue(this DenseTensor<float> tensor, OnnxNamedMetadata metadata)
{
var dimensions = tensor.Dimensions.ToLong();
return metadata.Value.ElementDataType switch
{
TensorElementType.Int64 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToLong(), dimensions),
TensorElementType.Float16 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToFloat16(), dimensions),
TensorElementType.BFloat16 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToBFloat16(), dimensions),
_ => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, dimensions)
};
}


/// <summary>
/// Converts DenseTensor<string> to OrtValue.
/// </summary>
/// <param name="tensor">The tensor.</param>
/// <returns></returns>
public static OrtValue ToOrtValue(this DenseTensor<string> tensor, OnnxNamedMetadata metadata)
{
return OrtValue.CreateFromStringTensor(tensor);
}


/// <summary>
/// Converts DenseTensor<int> to OrtValue.
/// </summary>
/// <param name="tensor">The tensor.</param>
/// <returns></returns>
public static OrtValue ToOrtValue(this DenseTensor<int> tensor, OnnxNamedMetadata metadata)
{
return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, tensor.Dimensions.ToLong());
}


/// <summary>
/// Creates and allocates the output tensors buffer.
/// </summary>
/// <param name="metadata">The metadata.</param>
/// <param name="dimensions">The dimensions.</param>
/// <returns></returns>
public static OrtValue CreateOutputBuffer(this OnnxNamedMetadata metadata, ReadOnlySpan<int> dimensions)
{
return OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, metadata.Value.ElementDataType, dimensions.ToLong());
}


/// <summary>
/// Converts to DenseTensor<float>.
/// TODO: Optimization
/// </summary>
/// <param name="ortValue">The ort value.</param>
/// <returns></returns>
public static DenseTensor<float> ToDenseTensor(this OrtValue ortValue)
{
var typeInfo = ortValue.GetTensorTypeAndShape();
var dimensions = typeInfo.Shape.ToInt();
return typeInfo.ElementDataType switch
{
TensorElementType.Float16 => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<Float16>().ToFloat(), dimensions),
TensorElementType.BFloat16 => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<BFloat16>().ToFloat(), dimensions),
_ => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<float>().ToArray(), dimensions)
};
}


/// <summary>
/// Converts to array.
/// TODO: Optimization
/// </summary>
/// <param name="ortValue">The ort value.</param>
/// <returns></returns>
public static float[] ToArray(this OrtValue ortValue)
{
var typeInfo = ortValue.GetTensorTypeAndShape();
return typeInfo.ElementDataType switch
{
TensorElementType.Float16 => ortValue.GetTensorDataAsSpan<Float16>().ToFloat().ToArray(),
TensorElementType.BFloat16 => ortValue.GetTensorDataAsSpan<BFloat16>().ToFloat().ToArray(),
_ => ortValue.GetTensorDataAsSpan<float>().ToArray()
};
}


/// <summary>
/// Converts to float16.
/// TODO: Optimization
/// </summary>
/// <param name="inputMemory">The input memory.</param>
/// <returns></returns>
private static Memory<Float16> ToFloat16(this Memory<float> inputMemory)
{
var elementCount = inputMemory.Length;
var floatArray = new Float16[inputMemory.Length];
for (int i = 0; i < elementCount; i++)
floatArray[i] = (Float16)inputMemory.Span[i];

return floatArray.AsMemory();
}


/// <summary>
/// Converts to BFloat16.
/// TODO: Optimization
/// </summary>
/// <param name="inputMemory">The input memory.</param>
/// <returns></returns>
private static Memory<BFloat16> ToBFloat16(this Memory<float> inputMemory)
{
var elementCount = inputMemory.Length;
var floatArray = new BFloat16[inputMemory.Length];
for (int i = 0; i < elementCount; i++)
floatArray[i] = (BFloat16)inputMemory.Span[i];

return floatArray.AsMemory();
}


/// <summary>
/// Converts to float.
/// TODO: Optimization
/// </summary>
/// <param name="inputMemory">The input memory.</param>
/// <returns></returns>
private static Memory<float> ToFloat(this ReadOnlySpan<Float16> inputMemory)
{
var elementCount = inputMemory.Length;
var floatArray = new float[elementCount];
for (int i = 0; i < elementCount; i++)
floatArray[i] = (float)inputMemory[i];

return floatArray.AsMemory();
}


/// <summary>
/// Converts to float.
/// TODO: Optimization
/// </summary>
/// <param name="inputMemory">The input memory.</param>
/// <returns></returns>
private static Memory<float> ToFloat(this ReadOnlySpan<BFloat16> inputMemory)
{
var elementCount = inputMemory.Length;
var floatArray = new float[elementCount];
for (int i = 0; i < elementCount; i++)
floatArray[i] = (float)inputMemory[i];

return floatArray.AsMemory();
}


/// <summary>
/// Converts to long.
/// </summary>
/// <param name="inputMemory">The input memory.</param>
/// <returns></returns>
private static Memory<long> ToLong(this Memory<float> inputMemory)
{
return Array.ConvertAll(inputMemory.ToArray(), Convert.ToInt64).AsMemory();
}
}
}
Loading