Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Improve ModelSet runtime management #63

Merged
merged 3 commits into from
Dec 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions OnnxStack.Console/Examples/StableDebug.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public async Task RunAsync()
Strength = 0.6f
};

foreach (var model in _stableDiffusionService.Models)
foreach (var model in _stableDiffusionService.ModelSets)
{
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
await _stableDiffusionService.LoadModelAsync(model);
Expand All @@ -71,7 +71,7 @@ public async Task RunAsync()
}


private async Task<bool> GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options)
private async Task<bool> GenerateImage(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options)
{
var timestamp = Stopwatch.GetTimestamp();
var outputFilename = Path.Combine(_outputDirectory, $"{model.Name}_{options.Seed}_{options.SchedulerType}.png");
Expand Down
2 changes: 1 addition & 1 deletion OnnxStack.Console/Examples/StableDiffusionBatch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public async Task RunAsync()
BatchType = BatchOptionType.Scheduler
};

foreach (var model in _stableDiffusionService.Models)
foreach (var model in _stableDiffusionService.ModelSets)
{
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
await _stableDiffusionService.LoadModelAsync(model);
Expand Down
4 changes: 2 additions & 2 deletions OnnxStack.Console/Examples/StableDiffusionExample.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public async Task RunAsync()
Seed = Random.Shared.Next()
};

foreach (var model in _stableDiffusionService.Models)
foreach (var model in _stableDiffusionService.ModelSets)
{
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
await _stableDiffusionService.LoadModelAsync(model);
Expand All @@ -65,7 +65,7 @@ public async Task RunAsync()
}
}

private async Task<bool> GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options)
private async Task<bool> GenerateImage(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options)
{
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}.png");
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);
Expand Down
4 changes: 2 additions & 2 deletions OnnxStack.Console/Examples/StableDiffusionGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public async Task RunAsync()
Directory.CreateDirectory(_outputDirectory);

var seed = Random.Shared.Next();
foreach (var model in _stableDiffusionService.Models)
foreach (var model in _stableDiffusionService.ModelSets)
{
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
await _stableDiffusionService.LoadModelAsync(model);
Expand Down Expand Up @@ -62,7 +62,7 @@ public async Task RunAsync()
OutputHelpers.ReadConsole(ConsoleColor.Gray);
}

private async Task<bool> GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options, string key)
private async Task<bool> GenerateImage(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, string key)
{
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}_{key}.png");
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);
Expand Down
2 changes: 1 addition & 1 deletion OnnxStack.Console/Examples/StableDiffusionGif.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public async Task RunAsync()
};

// Choose Model
var model = _stableDiffusionService.Models.FirstOrDefault(x => x.Name == "LCM-Dreamshaper-V7");
var model = _stableDiffusionService.ModelSets.FirstOrDefault(x => x.Name == "LCM-Dreamshaper-V7");
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
await _stableDiffusionService.LoadModelAsync(model);

Expand Down
4 changes: 2 additions & 2 deletions OnnxStack.Console/appsettings.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
}
},
"AllowedHosts": "*",
"OnnxStackConfig": {
"OnnxModelSets": [
"StableDiffusionConfig": {
"ModelSets": [
{
"Name": "StableDiffusion 1.5",
"IsEnabled": true,
Expand Down
2 changes: 1 addition & 1 deletion OnnxStack.Core/Config/IOnnxModelSetConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ public interface IOnnxModelSetConfig : IOnnxModel
int IntraOpNumThreads { get; set; }
ExecutionMode ExecutionMode { get; set; }
ExecutionProvider ExecutionProvider { get; set; }
List<OnnxModelSessionConfig> ModelConfigurations { get; set; }
List<OnnxModelConfig> ModelConfigurations { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace OnnxStack.Core.Config
{
public class OnnxModelSessionConfig
public class OnnxModelConfig
{
public OnnxModelType Type { get; set; }
public string OnnxModelPath { get; set; }
Expand Down
17 changes: 17 additions & 0 deletions OnnxStack.Core/Config/OnnxModelEqualityComparer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using System.Collections.Generic;

namespace OnnxStack.Core.Config
{
public class OnnxModelEqualityComparer : IEqualityComparer<IOnnxModel>
{
public bool Equals(IOnnxModel x, IOnnxModel y)
{
return x != null && y != null && x.Name == y.Name;
}

public int GetHashCode(IOnnxModel obj)
{
return obj?.Name?.GetHashCode() ?? 0;
}
}
}
2 changes: 1 addition & 1 deletion OnnxStack.Core/Config/OnnxModelSetConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ public class OnnxModelSetConfig : IOnnxModelSetConfig
public int IntraOpNumThreads { get; set; }
public ExecutionMode ExecutionMode { get; set; }
public ExecutionProvider ExecutionProvider { get; set; }
public List<OnnxModelSessionConfig> ModelConfigurations { get; set; }
public List<OnnxModelConfig> ModelConfigurations { get; set; }
}
}
11 changes: 0 additions & 11 deletions OnnxStack.Core/Config/OnnxStackConfig.cs
Original file line number Diff line number Diff line change
@@ -1,22 +1,11 @@
using OnnxStack.Common.Config;
using System.Collections.Generic;
using System.Linq;

namespace OnnxStack.Core.Config
{
public class OnnxStackConfig : IConfigSection
{
public List<OnnxModelSetConfig> OnnxModelSets { get; set; } = new List<OnnxModelSetConfig>();

public void Initialize()
{
if (OnnxModelSets.IsNullOrEmpty())
return;

foreach (var modelSet in OnnxModelSets)
{
modelSet.ApplyConfigurationOverrides();
}
}
}
}
2 changes: 1 addition & 1 deletion OnnxStack.Core/Extensions/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace OnnxStack.Core
{
public static class Extensions
{
public static SessionOptions GetSessionOptions(this OnnxModelSessionConfig configuration)
public static SessionOptions GetSessionOptions(this OnnxModelConfig configuration)
{
var sessionOptions = new SessionOptions
{
Expand Down
6 changes: 3 additions & 3 deletions OnnxStack.Core/Model/OnnxModelSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ public class OnnxModelSession : IDisposable
{
private readonly SessionOptions _options;
private readonly InferenceSession _session;
private readonly OnnxModelSessionConfig _configuration;
private readonly OnnxModelConfig _configuration;

/// <summary>
/// Initializes a new instance of the <see cref="OnnxModelSession"/> class.
/// </summary>
/// <param name="configuration">The configuration.</param>
/// <param name="container">The container.</param>
/// <exception cref="System.IO.FileNotFoundException">Onnx model file not found</exception>
public OnnxModelSession(OnnxModelSessionConfig configuration, PrePackedWeightsContainer container)
public OnnxModelSession(OnnxModelConfig configuration, PrePackedWeightsContainer container)
{
if (!File.Exists(configuration.OnnxModelPath))
throw new FileNotFoundException("Onnx model file not found", configuration.OnnxModelPath);
Expand All @@ -44,7 +44,7 @@ public OnnxModelSession(OnnxModelSessionConfig configuration, PrePackedWeightsCo
/// <summary>
/// Gets the configuration.
/// </summary>
public OnnxModelSessionConfig Configuration => _configuration;
public OnnxModelConfig Configuration => _configuration;


/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion OnnxStack.Core/Model/OnnxModelSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public InferenceSession GetSession(OnnxModelType modelType)
/// </summary>
/// <param name="modelType">Type of the model.</param>
/// <returns></returns>
public OnnxModelSessionConfig GetConfiguration(OnnxModelType modelType)
public OnnxModelConfig GetConfiguration(OnnxModelType modelType)
{
return _configuration.ModelConfigurations.FirstOrDefault(x => x.Type == modelType);
}
Expand Down
19 changes: 18 additions & 1 deletion OnnxStack.Core/Registration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public static class Registration
/// <param name="serviceCollection">The service collection.</param>
public static void AddOnnxStack(this IServiceCollection serviceCollection)
{
serviceCollection.AddSingleton(ConfigManager.LoadConfiguration());
serviceCollection.AddSingleton(TryLoadAppSettings());
serviceCollection.AddSingleton<IOnnxModelService, OnnxModelService>();
}

Expand All @@ -43,5 +43,22 @@ public static void AddOnnxStackConfig<T>(this IServiceCollection serviceCollecti
{
serviceCollection.AddSingleton(ConfigManager.LoadConfiguration<T>());
}


/// <summary>
/// Try load OnnxStackConfig from application settings if it exists.
/// </summary>
/// <returns></returns>
private static OnnxStackConfig TryLoadAppSettings()
{
try
{
return ConfigManager.LoadConfiguration<OnnxStackConfig>();
}
catch
{
return new OnnxStackConfig();
}
}
}
}
16 changes: 7 additions & 9 deletions OnnxStack.Core/Services/IOnnxModelService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,26 @@ public interface IOnnxModelService : IDisposable
/// <returns></returns>
Task<bool> AddModelSet(IOnnxModelSetConfig modelSet);


/// <summary>
/// Adds a collection of ModelSet
/// </summary>
/// <param name="modelSets">The model sets.</param>
Task AddModelSet(IEnumerable<IOnnxModelSetConfig> modelSets);


/// <summary>
/// Removes a model set.
/// </summary>
/// <param name="modelSet">The model set.</param>
/// <returns></returns>
Task<bool> RemoveModelSet(IOnnxModelSetConfig modelSet);

/// <summary>
/// Updates the model set.
/// </summary>
/// <param name="modelSet">The model set.</param>
/// <returns></returns>
Task<bool> UpdateModelSet(IOnnxModelSetConfig modelSet);

/// <summary>
/// Loads the model.
/// </summary>
Expand All @@ -65,13 +70,6 @@ public interface IOnnxModelService : IDisposable
bool IsModelLoaded(IOnnxModel model);


/// <summary>
/// Updates the model set.
/// </summary>
/// <param name="modelSet">The model set.</param>
/// <returns></returns>
bool UpdateModelSet(IOnnxModelSetConfig modelSet);

/// <summary>
/// Determines whether the specified model type is enabled.
/// </summary>
Expand Down
40 changes: 18 additions & 22 deletions OnnxStack.Core/Services/OnnxModelService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ namespace OnnxStack.Core.Services
public sealed class OnnxModelService : IOnnxModelService
{
private readonly OnnxStackConfig _configuration;
private readonly ConcurrentDictionary<string, OnnxModelSet> _onnxModelSets;
private readonly ConcurrentDictionary<string, IOnnxModelSetConfig> _onnxModelSetConfigs;
private readonly ConcurrentDictionary<IOnnxModel, OnnxModelSet> _onnxModelSets;
private readonly ConcurrentDictionary<IOnnxModel, IOnnxModelSetConfig> _onnxModelSetConfigs;

/// <summary>
/// Initializes a new instance of the <see cref="OnnxModelService"/> class.
Expand All @@ -26,8 +26,8 @@ public sealed class OnnxModelService : IOnnxModelService
public OnnxModelService(OnnxStackConfig configuration)
{
_configuration = configuration;
_onnxModelSets = new ConcurrentDictionary<string, OnnxModelSet>();
_onnxModelSetConfigs = _configuration.OnnxModelSets.ToConcurrentDictionary(x => x.Name, x => x as IOnnxModelSetConfig);
_onnxModelSets = new ConcurrentDictionary<IOnnxModel, OnnxModelSet>(new OnnxModelEqualityComparer());
_onnxModelSetConfigs = new ConcurrentDictionary<IOnnxModel, IOnnxModelSetConfig>(new OnnxModelEqualityComparer());
}


Expand All @@ -50,7 +50,7 @@ public OnnxModelService(OnnxStackConfig configuration)
/// <returns></returns>
public Task<bool> AddModelSet(IOnnxModelSetConfig modelSet)
{
return Task.FromResult(_onnxModelSetConfigs.TryAdd(modelSet.Name, modelSet));
return Task.FromResult(_onnxModelSetConfigs.TryAdd(modelSet, modelSet));
}

/// <summary>
Expand All @@ -74,7 +74,7 @@ public Task AddModelSet(IEnumerable<IOnnxModelSetConfig> modelSets)
/// <returns></returns>
public Task<bool> RemoveModelSet(IOnnxModelSetConfig modelSet)
{
return Task.FromResult(_onnxModelSetConfigs.TryRemove(modelSet.Name, out _));
return Task.FromResult(_onnxModelSetConfigs.TryRemove(modelSet, out _));
}


Expand All @@ -83,10 +83,10 @@ public Task<bool> RemoveModelSet(IOnnxModelSetConfig modelSet)
/// </summary>
/// <param name="modelSet">The model set.</param>
/// <returns></returns>
public bool UpdateModelSet(IOnnxModelSetConfig modelSet)
public Task<bool> UpdateModelSet(IOnnxModelSetConfig modelSet)
{
_onnxModelSetConfigs.TryRemove(modelSet.Name, out _);
return _onnxModelSetConfigs.TryAdd(modelSet.Name, modelSet);
_onnxModelSetConfigs.TryRemove(modelSet, out _);
return Task.FromResult(_onnxModelSetConfigs.TryAdd(modelSet, modelSet));
}


Expand Down Expand Up @@ -120,7 +120,7 @@ public async Task<bool> UnloadModelAsync(IOnnxModel model)
/// </returns>
public bool IsModelLoaded(IOnnxModel model)
{
return _onnxModelSets.ContainsKey(model.Name);
return _onnxModelSets.ContainsKey(model);
}


Expand Down Expand Up @@ -251,7 +251,7 @@ private OnnxMetadata GetNodeMetadataInternal(IOnnxModel model, OnnxModelType mod
/// <exception cref="System.Exception">Model {model.Name} has not been loaded</exception>
private OnnxModelSet GetModelSet(IOnnxModel model)
{
if (!_onnxModelSets.TryGetValue(model.Name, out var modelSet))
if (!_onnxModelSets.TryGetValue(model, out var modelSet))
throw new Exception($"Model {model.Name} has not been loaded");

return modelSet;
Expand All @@ -266,17 +266,17 @@ private OnnxModelSet GetModelSet(IOnnxModel model)
/// <exception cref="System.Exception">Model {model.Name} not found in configuration</exception>
private OnnxModelSet LoadModelSet(IOnnxModel model)
{
if (_onnxModelSets.ContainsKey(model.Name))
return _onnxModelSets[model.Name];
if (_onnxModelSets.ContainsKey(model))
return _onnxModelSets[model];

if (!_onnxModelSetConfigs.TryGetValue(model.Name, out var modelSetConfig))
throw new Exception($"Model {model.Name} not found in configuration");
if (!_onnxModelSetConfigs.TryGetValue(model, out var modelSetConfig))
throw new Exception($"Model {model.Name} not found");

if (!modelSetConfig.IsEnabled)
throw new Exception($"Model {model.Name} is not enabled");

var modelSet = new OnnxModelSet(modelSetConfig);
_onnxModelSets.TryAdd(model.Name, modelSet);
_onnxModelSets.TryAdd(model, modelSet);
return modelSet;
}

Expand All @@ -288,10 +288,10 @@ private OnnxModelSet LoadModelSet(IOnnxModel model)
/// <returns></returns>
private bool UnloadModelSet(IOnnxModel model)
{
if (!_onnxModelSets.TryGetValue(model.Name, out var modelSet))
if (!_onnxModelSets.TryGetValue(model, out _))
return true;

if (_onnxModelSets.TryRemove(model.Name, out modelSet))
if (_onnxModelSets.TryRemove(model, out var modelSet))
{
modelSet?.Dispose();
return true;
Expand All @@ -310,9 +310,5 @@ public void Dispose()
onnxModelSet?.Dispose();
}
}


}


}
2 changes: 1 addition & 1 deletion OnnxStack.ImageUpscaler/Config/UpscaleModelSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ public class UpscaleModelSet : IOnnxModelSetConfig
public int IntraOpNumThreads { get; set; }
public ExecutionMode ExecutionMode { get; set; }
public ExecutionProvider ExecutionProvider { get; set; }
public List<OnnxModelSessionConfig> ModelConfigurations { get; set; }
public List<OnnxModelConfig> ModelConfigurations { get; set; }
}
}
Loading