diff --git a/OnnxStack.Console/Examples/LoRADebug.cs b/OnnxStack.Console/Examples/LoRADebug.cs new file mode 100644 index 00000000..566458c5 --- /dev/null +++ b/OnnxStack.Console/Examples/LoRADebug.cs @@ -0,0 +1,42 @@ +using Microsoft.ML.OnnxRuntime; +using OnnxStack.Core.Services; + +namespace OnnxStack.Console.Runner +{ + public sealed class LoRADebug : IExampleRunner + { + private readonly string _outputDirectory; + private readonly IOnnxModelService _modelService; + private readonly IOnnxModelAdaptaterService _modelAdaptaterService; + + public LoRADebug(IOnnxModelService modelService) + { + _modelService = modelService; + _outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", nameof(StableDebug)); + } + + public string Name => "LoRA Debug"; + + public string Description => "LoRA Debugger"; + + public async Task RunAsync() + { + string modelPath = "D:\\Repositories\\stable-diffusion-v1-5\\unet\\model.onnx"; + string loraModelPath = "D:\\Repositories\\LoRAFiles\\model.onnx"; + + using (var modelession = new InferenceSession(modelPath)) + using (var loraModelSession = new InferenceSession(loraModelPath)) + { + try + { + _modelAdaptaterService.ApplyLowRankAdaptation(modelession, loraModelSession); + } + catch (Exception ex) + { + + } + } + } + + } +} diff --git a/OnnxStack.Core/Model/OnnxModelAdapter.cs b/OnnxStack.Core/Model/OnnxModelAdapter.cs new file mode 100644 index 00000000..39382b75 --- /dev/null +++ b/OnnxStack.Core/Model/OnnxModelAdapter.cs @@ -0,0 +1,14 @@ +using OnnxStack.Core.Config; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace OnnxStack.Core.Model +{ + public class OnnxModelAdapter : IOnnxModel + { + public string Name { get; set; } + } +} diff --git a/OnnxStack.Core/Registration.cs b/OnnxStack.Core/Registration.cs index fbcd9039..521b31ad 100644 --- a/OnnxStack.Core/Registration.cs +++ b/OnnxStack.Core/Registration.cs @@ -18,6 +18,7 @@ public static void AddOnnxStack(this IServiceCollection serviceCollection) { serviceCollection.AddSingleton(ConfigManager.LoadConfiguration()); serviceCollection.AddSingleton(); + serviceCollection.AddSingleton(); } diff --git a/OnnxStack.Core/Services/IOnnxModelAdaptaterService.cs b/OnnxStack.Core/Services/IOnnxModelAdaptaterService.cs new file mode 100644 index 00000000..0505ad43 --- /dev/null +++ b/OnnxStack.Core/Services/IOnnxModelAdaptaterService.cs @@ -0,0 +1,9 @@ +using Microsoft.ML.OnnxRuntime; + +namespace OnnxStack.Core.Services +{ + public interface IOnnxModelAdaptaterService + { + void ApplyLowRankAdaptation(InferenceSession primarySession, InferenceSession loraSession); + } +} \ No newline at end of file diff --git a/OnnxStack.Core/Services/OnnxModelAdaptaterService.cs b/OnnxStack.Core/Services/OnnxModelAdaptaterService.cs new file mode 100644 index 00000000..92be511b --- /dev/null +++ b/OnnxStack.Core/Services/OnnxModelAdaptaterService.cs @@ -0,0 +1,50 @@ +using Microsoft.ML.OnnxRuntime; +using Microsoft.ML.OnnxRuntime.Tensors; +using System; +using System.Linq; + +namespace OnnxStack.Core.Services +{ + public class OnnxModelAdaptaterService : IOnnxModelAdaptaterService + { + public void ApplyLowRankAdaptation(InferenceSession primarySession, InferenceSession loraSession) + { + // For simplicity, let's assume we will replace the weights of the first dense layer + string layerName = "layer_name"; + + // Get the current weights from the primary model + var primaryInputName = primarySession.InputMetadata.Keys.First(); + var primaryInputTensor = primarySession.InputMetadata[primaryInputName]; + var primaryWeights = new float[primaryInputTensor.Dimensions.Product()]; + + // Get the weights from the LoRA model + var lraInputName = loraSession.InputMetadata.Keys.First(); + var lraInputTensor = loraSession.InputMetadata[lraInputName]; + var lraWeights = new float[lraInputTensor.Dimensions.Product()]; + + // Apply LoRA (replace weights) this is where we will do the mutiplication of the weights + // but for testing sake just brute for replacing + Array.Copy(lraWeights, primaryWeights, Math.Min(primaryWeights.Length, lraWeights.Length)); + + // Update the primary model tensor with the modified weights + var tensor = new DenseTensor(primaryWeights, primaryInputTensor.Dimensions.ToArray()); + var inputs = new NamedOnnxValue[] { NamedOnnxValue.CreateFromTensor(primaryInputName, tensor) }; + + // Will it run? + primarySession.Run(inputs); + } + } + + public static class Ext + { + public static int Product(this int[] array) + { + int result = 1; + foreach (int element in array) + { + result *= element; + } + return result; + } + } +}