Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

LoRA loading and LoRA application #20

Closed
wants to merge 9 commits into from
Closed
42 changes: 42 additions & 0 deletions OnnxStack.Console/Examples/LoRADebug.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using Microsoft.ML.OnnxRuntime;
using OnnxStack.Core.Services;

namespace OnnxStack.Console.Runner
{
public sealed class LoRADebug : IExampleRunner
{
private readonly string _outputDirectory;
private readonly IOnnxModelService _modelService;
private readonly IOnnxModelAdaptaterService _modelAdaptaterService;

public LoRADebug(IOnnxModelService modelService)
{
_modelService = modelService;
_outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", nameof(StableDebug));
}

public string Name => "LoRA Debug";

public string Description => "LoRA Debugger";

public async Task RunAsync()
{
string modelPath = "D:\\Repositories\\stable-diffusion-v1-5\\unet\\model.onnx";
string loraModelPath = "D:\\Repositories\\LoRAFiles\\model.onnx";

using (var modelession = new InferenceSession(modelPath))
using (var loraModelSession = new InferenceSession(loraModelPath))
{
try
{
_modelAdaptaterService.ApplyLowRankAdaptation(modelession, loraModelSession);
}
catch (Exception ex)
{

}
}
}

}
}
14 changes: 14 additions & 0 deletions OnnxStack.Core/Model/OnnxModelAdapter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using OnnxStack.Core.Config;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace OnnxStack.Core.Model
{
public class OnnxModelAdapter : IOnnxModel
{
public string Name { get; set; }
}
}
1 change: 1 addition & 0 deletions OnnxStack.Core/Registration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public static void AddOnnxStack(this IServiceCollection serviceCollection)
{
serviceCollection.AddSingleton(ConfigManager.LoadConfiguration());
serviceCollection.AddSingleton<IOnnxModelService, OnnxModelService>();
serviceCollection.AddSingleton<IOnnxModelAdaptaterService, OnnxModelAdaptaterService>();
}


Expand Down
9 changes: 9 additions & 0 deletions OnnxStack.Core/Services/IOnnxModelAdaptaterService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using Microsoft.ML.OnnxRuntime;

namespace OnnxStack.Core.Services
{
public interface IOnnxModelAdaptaterService
{
void ApplyLowRankAdaptation(InferenceSession primarySession, InferenceSession loraSession);
}
}
50 changes: 50 additions & 0 deletions OnnxStack.Core/Services/OnnxModelAdaptaterService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Linq;

namespace OnnxStack.Core.Services
{
public class OnnxModelAdaptaterService : IOnnxModelAdaptaterService
{
public void ApplyLowRankAdaptation(InferenceSession primarySession, InferenceSession loraSession)
{
// For simplicity, let's assume we will replace the weights of the first dense layer
string layerName = "layer_name";

// Get the current weights from the primary model
var primaryInputName = primarySession.InputMetadata.Keys.First();
var primaryInputTensor = primarySession.InputMetadata[primaryInputName];
var primaryWeights = new float[primaryInputTensor.Dimensions.Product()];

// Get the weights from the LoRA model
var lraInputName = loraSession.InputMetadata.Keys.First();
var lraInputTensor = loraSession.InputMetadata[lraInputName];
var lraWeights = new float[lraInputTensor.Dimensions.Product()];

// Apply LoRA (replace weights) this is where we will do the mutiplication of the weights
// but for testing sake just brute for replacing
Array.Copy(lraWeights, primaryWeights, Math.Min(primaryWeights.Length, lraWeights.Length));

// Update the primary model tensor with the modified weights
var tensor = new DenseTensor<float>(primaryWeights, primaryInputTensor.Dimensions.ToArray());
var inputs = new NamedOnnxValue[] { NamedOnnxValue.CreateFromTensor(primaryInputName, tensor) };

// Will it run?
primarySession.Run(inputs);
}
}

public static class Ext
{
public static int Product(this int[] array)
{
int result = 1;
foreach (int element in array)
{
result *= element;
}
return result;
}
}
}