Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit fc9ef24

Browse files
authored
Merge pull request #88 from saddam213/ControlNet
Add ControlNet Support
2 parents 3c12247 + cd637a3 commit fc9ef24

File tree

79 files changed

+1937
-356
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+1937
-356
lines changed

OnnxStack.Console/Examples/StableDebug.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ private async Task<bool> GenerateImage(StableDiffusionModelSet model, PromptOpti
7777
{
7878
var timestamp = Stopwatch.GetTimestamp();
7979
var outputFilename = Path.Combine(_outputDirectory, $"{model.Name}_{options.Seed}_{options.SchedulerType}.png");
80-
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);
80+
var result = await _stableDiffusionService.GenerateAsImageAsync(new ModelOptions(model), prompt, options);
8181
if (result is not null)
8282
{
8383
await result.SaveAsPngAsync(outputFilename);

OnnxStack.Console/Examples/StableDiffusionBatch.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public async Task RunAsync()
6666
OutputHelpers.WriteConsole($"Image: {progress.BatchValue}/{progress.BatchMax} - Step: {progress.StepValue}/{progress.StepMax}", ConsoleColor.Cyan);
6767
};
6868

69-
await foreach (var result in _stableDiffusionService.GenerateBatchAsync(model, promptOptions, schedulerOptions, batchOptions, default))
69+
await foreach (var result in _stableDiffusionService.GenerateBatchAsync(new ModelOptions(model), promptOptions, schedulerOptions, batchOptions, default))
7070
{
7171
var outputFilename = Path.Combine(_outputDirectory, $"{batchIndex}_{result.SchedulerOptions.Seed}.png");
7272
var image = result.ImageResult.ToImage();

OnnxStack.Console/Examples/StableDiffusionExample.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ public async Task RunAsync()
7070
private async Task<bool> GenerateImage(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options)
7171
{
7272
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}.png");
73-
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);
73+
var result = await _stableDiffusionService.GenerateAsImageAsync(new ModelOptions(model), prompt, options);
7474
if (result == null)
7575
return false;
7676

OnnxStack.Console/Examples/StableDiffusionGenerator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public async Task RunAsync()
6767
private async Task<bool> GenerateImage(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, string key)
6868
{
6969
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}_{key}.png");
70-
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);
70+
var result = await _stableDiffusionService.GenerateAsImageAsync(new ModelOptions(model), prompt, options);
7171
if (result == null)
7272
return false;
7373

OnnxStack.Console/Examples/StableDiffusionGif.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ public async Task RunAsync()
102102

103103
// Set prompt Image, Run Diffusion
104104
promptOptions.InputImage = new InputImage(mergedFrame.CloneAs<Rgba32>());
105-
var result = await _stableDiffusionService.GenerateAsImageAsync(model, promptOptions, schedulerOptions);
105+
var result = await _stableDiffusionService.GenerateAsImageAsync(new ModelOptions(model), promptOptions, schedulerOptions);
106106

107107
// Save Debug Output
108108
await result.SaveAsPngAsync(Path.Combine(_outputDirectory, $"Debug-Output.png"));

OnnxStack.Core/Config/OnnxModelType.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ public enum OnnxModelType
99
TextEncoder2 = 21,
1010
VaeEncoder = 30,
1111
VaeDecoder = 40,
12+
ControlNet = 50,
13+
Annotation = 51,
1214
Upscaler = 1000
1315
}
1416
}

OnnxStack.Core/Extensions/OrtValueExtensions.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ public static OrtValue ToOrtValue(this DenseTensor<long> tensor, OnnxNamedMetada
5959
return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, tensor.Dimensions.ToLong());
6060
}
6161

62+
public static OrtValue ToOrtValue(this DenseTensor<double> tensor, OnnxNamedMetadata metadata)
63+
{
64+
return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, tensor.Dimensions.ToLong());
65+
}
66+
6267

6368
/// <summary>
6469
/// Creates and allocates the output tensors buffer.

OnnxStack.Core/Image/Extensions.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,23 @@ public static async Task<byte[]> ToImageBytesAsync(this DenseTensor<float> image
6363
}
6464
}
6565

66+
public static Image<Rgba32> ToImageMask(this DenseTensor<float> imageTensor)
67+
{
68+
var width = imageTensor.Dimensions[3];
69+
var height = imageTensor.Dimensions[2];
70+
using (var result = new Image<L8>(width, height))
71+
{
72+
for (var y = 0; y < height; y++)
73+
{
74+
for (var x = 0; x < width; x++)
75+
{
76+
result[x, y] = new L8((byte)(imageTensor[0, 0, y, x] * 255.0f));
77+
}
78+
}
79+
return result.CloneAs<Rgba32>();
80+
}
81+
}
82+
6683

6784
private static byte CalculateByte(Tensor<float> imageTensor, int index, int y, int x)
6885
{

OnnxStack.Core/Model/OnnxInferenceParameters.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ public void AddInputTensor(DenseTensor<float> value)
4545
_inputs.Add(metaData, value.ToOrtValue(metaData));
4646
}
4747

48+
public void AddInputTensor(DenseTensor<double> value)
49+
{
50+
var metaData = GetNextInputMetadata();
51+
_inputs.Add(metaData, value.ToOrtValue(metaData));
52+
}
53+
4854

4955
/// <summary>
5056
/// Adds the input tensor.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using OnnxStack.Core.Image;
2+
using OnnxStack.StableDiffusion.Config;
3+
using System.Threading.Tasks;
4+
5+
namespace OnnxStack.StableDiffusion.Common
6+
{
7+
public interface IControlNetImageService
8+
{
9+
10+
/// <summary>
11+
/// Prepares the ContolNet input image, If the ControlNetModelSet has a configure Annotation model this will be used to process the image
12+
/// </summary>
13+
/// <param name="controlNetModel">The control net model.</param>
14+
/// <param name="inputImage">The input image.</param>
15+
/// <param name="height">The height.</param>
16+
/// <param name="width">The width.</param>
17+
/// <returns></returns>
18+
Task<InputImage> PrepareInputImage(ControlNetModelSet controlNetModel, InputImage inputImage, int height, int width);
19+
}
20+
}

OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
2+
using OnnxStack.Core.Config;
23
using OnnxStack.StableDiffusion.Config;
34
using OnnxStack.StableDiffusion.Models;
45
using SixLabors.ImageSharp;
@@ -18,15 +19,15 @@ public interface IStableDiffusionService
1819
/// </summary>
1920
/// <param name="modelOptions">The model options.</param>
2021
/// <returns></returns>
21-
Task<bool> LoadModelAsync(StableDiffusionModelSet model);
22+
Task<bool> LoadModelAsync(IOnnxModelSetConfig model);
2223

2324

2425
/// <summary>
2526
/// Unloads the model.
2627
/// </summary>
2728
/// <param name="modelOptions">The model options.</param>
2829
/// <returns></returns>
29-
Task<bool> UnloadModelAsync(StableDiffusionModelSet model);
30+
Task<bool> UnloadModelAsync(IOnnxModel model);
3031

3132
/// <summary>
3233
/// Determines whether the specified model is loaded
@@ -35,7 +36,7 @@ public interface IStableDiffusionService
3536
/// <returns>
3637
/// <c>true</c> if the specified model is loaded; otherwise, <c>false</c>.
3738
/// </returns>
38-
bool IsModelLoaded(StableDiffusionModelSet model);
39+
bool IsModelLoaded(IOnnxModel model);
3940

4041
/// <summary>
4142
/// Generates the StableDiffusion image using the prompt and options provided.
@@ -45,7 +46,7 @@ public interface IStableDiffusionService
4546
/// <param name="progressCallback">The callback used to provide progess of the current InferenceSteps.</param>
4647
/// <param name="cancellationToken">The cancellation token.</param>
4748
/// <returns>The diffusion result as <see cref="DenseTensor<float>"/></returns>
48-
Task<DenseTensor<float>> GenerateAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
49+
Task<DenseTensor<float>> GenerateAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
4950

5051
/// <summary>
5152
/// Generates the StableDiffusion image using the prompt and options provided.
@@ -55,7 +56,7 @@ public interface IStableDiffusionService
5556
/// <param name="progressCallback">The callback used to provide progess of the current InferenceSteps.</param>
5657
/// <param name="cancellationToken">The cancellation token.</param>
5758
/// <returns>The diffusion result as <see cref="SixLabors.ImageSharp.Image<Rgba32>"/></returns>
58-
Task<Image<Rgba32>> GenerateAsImageAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
59+
Task<Image<Rgba32>> GenerateAsImageAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
5960

6061
/// <summary>
6162
/// Generates the StableDiffusion image using the prompt and options provided.
@@ -65,7 +66,7 @@ public interface IStableDiffusionService
6566
/// <param name="progressCallback">The callback used to provide progess of the current InferenceSteps.</param>
6667
/// <param name="cancellationToken">The cancellation token.</param>
6768
/// <returns>The diffusion result as <see cref="byte[]"/></returns>
68-
Task<byte[]> GenerateAsBytesAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
69+
Task<byte[]> GenerateAsBytesAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
6970

7071
/// <summary>
7172
/// Generates the StableDiffusion image using the prompt and options provided.
@@ -75,7 +76,7 @@ public interface IStableDiffusionService
7576
/// <param name="progressCallback">The callback used to provide progess of the current InferenceSteps.</param>
7677
/// <param name="cancellationToken">The cancellation token.</param>
7778
/// <returns>The diffusion result as <see cref="System.IO.Stream"/></returns>
78-
Task<Stream> GenerateAsStreamAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
79+
Task<Stream> GenerateAsStreamAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
7980

8081
/// <summary>
8182
/// Generates a batch of StableDiffusion image using the prompt and options provided.
@@ -87,7 +88,7 @@ public interface IStableDiffusionService
8788
/// <param name="progressCallback">The progress callback.</param>
8889
/// <param name="cancellationToken">The cancellation token.</param>
8990
/// <returns></returns>
90-
IAsyncEnumerable<BatchResult> GenerateBatchAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
91+
IAsyncEnumerable<BatchResult> GenerateBatchAsync(ModelOptions model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
9192

9293
/// <summary>
9394
/// Generates a batch of StableDiffusion image using the prompt and options provided.
@@ -99,7 +100,7 @@ public interface IStableDiffusionService
99100
/// <param name="progressCallback">The progress callback.</param>
100101
/// <param name="cancellationToken">The cancellation token.</param>
101102
/// <returns></returns>
102-
IAsyncEnumerable<Image<Rgba32>> GenerateBatchAsImageAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
103+
IAsyncEnumerable<Image<Rgba32>> GenerateBatchAsImageAsync(ModelOptions model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
103104

104105
/// <summary>
105106
/// Generates a batch of StableDiffusion image using the prompt and options provided.
@@ -111,7 +112,7 @@ public interface IStableDiffusionService
111112
/// <param name="progressCallback">The progress callback.</param>
112113
/// <param name="cancellationToken">The cancellation token.</param>
113114
/// <returns></returns>
114-
IAsyncEnumerable<byte[]> GenerateBatchAsBytesAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
115+
IAsyncEnumerable<byte[]> GenerateBatchAsBytesAsync(ModelOptions model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
115116

116117
/// <summary>
117118
/// Generates a batch of StableDiffusion image using the prompt and options provided.
@@ -123,6 +124,6 @@ public interface IStableDiffusionService
123124
/// <param name="progressCallback">The progress callback.</param>
124125
/// <param name="cancellationToken">The cancellation token.</param>
125126
/// <returns></returns>
126-
IAsyncEnumerable<Stream> GenerateBatchAsStreamAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
127+
IAsyncEnumerable<Stream> GenerateBatchAsStreamAsync(ModelOptions model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
127128
}
128129
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using Microsoft.ML.OnnxRuntime;
2+
using OnnxStack.Core.Config;
3+
using OnnxStack.StableDiffusion.Enums;
4+
using System.Collections.Generic;
5+
6+
namespace OnnxStack.StableDiffusion.Config
7+
{
8+
public record ControlNetModelSet : IOnnxModelSetConfig
9+
{
10+
public ControlNetType Type { get; set; }
11+
public string Name { get; set; }
12+
public bool IsEnabled { get; set; }
13+
public int DeviceId { get; set; }
14+
public int InterOpNumThreads { get; set; }
15+
public int IntraOpNumThreads { get; set; }
16+
public ExecutionMode ExecutionMode { get; set; }
17+
public ExecutionProvider ExecutionProvider { get; set; }
18+
public List<OnnxModelConfig> ModelConfigurations { get; set; }
19+
}
20+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using OnnxStack.StableDiffusion.Enums;
2+
3+
namespace OnnxStack.StableDiffusion.Config
4+
{
5+
public record ModelOptions(StableDiffusionModelSet BaseModel, ControlNetModelSet ControlNetModel = default)
6+
{
7+
public string Name => BaseModel.Name;
8+
public DiffuserPipelineType PipelineType => BaseModel.PipelineType;
9+
public float ScaleFactor => BaseModel.ScaleFactor;
10+
public ModelType ModelType => BaseModel.ModelType;
11+
}
12+
}

OnnxStack.StableDiffusion/Config/PromptOptions.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ public class PromptOptions
2020

2121
public InputImage InputImageMask { get; set; }
2222

23+
public InputImage InputContolImage { get; set; }
24+
2325
public VideoInput InputVideo { get; set; }
2426

2527
public float VideoInputFPS { get; set; }

OnnxStack.StableDiffusion/Config/SchedulerOptions.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ public record SchedulerOptions
8484
public float AestheticScore { get; set; } = 6f;
8585
public float AestheticNegativeScore { get; set; } = 2.5f;
8686

87+
public float ConditioningScale { get; set; } = 0.7f;
88+
public bool IsControlImageProcessingEnabled { get; set; }
89+
8790
public bool IsKarrasScheduler
8891
{
8992
get

0 commit comments

Comments
 (0)