Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 72473e8

Browse files
authored
Merge pull request #96 from saddam213/ControlNet
LCM ControlNet, SDXL ControlNet, InstaFlow ControlNet
2 parents eaec794 + 30d2cc9 commit 72473e8

File tree

9 files changed

+1093
-29
lines changed

9 files changed

+1093
-29
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core;
4+
using OnnxStack.Core.Config;
5+
using OnnxStack.Core.Model;
6+
using OnnxStack.Core.Services;
7+
using OnnxStack.StableDiffusion.Common;
8+
using OnnxStack.StableDiffusion.Config;
9+
using OnnxStack.StableDiffusion.Enums;
10+
using OnnxStack.StableDiffusion.Helpers;
11+
using OnnxStack.StableDiffusion.Models;
12+
using System;
13+
using System.Collections.Generic;
14+
using System.Diagnostics;
15+
using System.Linq;
16+
using System.Threading;
17+
using System.Threading.Tasks;
18+
19+
namespace OnnxStack.StableDiffusion.Diffusers.InstaFlow
20+
{
21+
public class ControlNetDiffuser : InstaFlowDiffuser
22+
{
23+
private readonly IControlNetImageService _controlNetImageService;
24+
25+
/// <summary>
26+
/// Initializes a new instance of the <see cref="ControlNetDiffuser"/> class.
27+
/// </summary>
28+
/// <param name="configuration">The configuration.</param>
29+
/// <param name="onnxModelService">The onnx model service.</param>
30+
public ControlNetDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, IControlNetImageService controlNetImageService, ILogger<ControlNetDiffuser> logger)
31+
: base(onnxModelService, promptService, logger)
32+
{
33+
_controlNetImageService = controlNetImageService;
34+
}
35+
36+
/// <summary>
37+
/// Gets the type of the diffuser.
38+
/// </summary>
39+
public override DiffuserType DiffuserType => DiffuserType.ControlNet;
40+
41+
42+
/// <summary>
43+
/// Called on each Scheduler step.
44+
/// </summary>
45+
/// <param name="modelOptions">The model options.</param>
46+
/// <param name="promptOptions">The prompt options.</param>
47+
/// <param name="schedulerOptions">The scheduler options.</param>
48+
/// <param name="promptEmbeddings">The prompt embeddings.</param>
49+
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
50+
/// <param name="progressCallback">The progress callback.</param>
51+
/// <param name="cancellationToken">The cancellation token.</param>
52+
/// <returns></returns>
53+
/// <exception cref="NotImplementedException"></exception>
54+
protected override async Task<DenseTensor<float>> SchedulerStepAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
55+
{
56+
// Get Scheduler
57+
using (var scheduler = GetScheduler(schedulerOptions))
58+
{
59+
// Get timesteps
60+
var timesteps = GetTimesteps(schedulerOptions, scheduler);
61+
62+
// Create latent sample
63+
var latents = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);
64+
65+
// Get Model metadata
66+
var metadata = _onnxModelService.GetModelMetadata(modelOptions.BaseModel, OnnxModelType.Unet);
67+
68+
// Get Model metadata
69+
var controlNetMetadata = _onnxModelService.GetModelMetadata(modelOptions.ControlNetModel, OnnxModelType.ControlNet);
70+
71+
// Control Image
72+
var controlImage = await PrepareControlImage(modelOptions, promptOptions, schedulerOptions);
73+
74+
// Get the distilled Timestep
75+
var distilledTimestep = 1.0f / timesteps.Count;
76+
77+
// Loop though the timesteps
78+
var step = 0;
79+
foreach (var timestep in timesteps)
80+
{
81+
step++;
82+
var stepTime = Stopwatch.GetTimestamp();
83+
cancellationToken.ThrowIfCancellationRequested();
84+
85+
// Create input tensor.
86+
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
87+
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
88+
var timestepTensor = CreateTimestepTensor(timestep);
89+
var controlImageTensor = performGuidance ? controlImage.Repeat(2) : controlImage;
90+
var conditioningScale = CreateConditioningScaleTensor(schedulerOptions.ConditioningScale);
91+
92+
var outputChannels = performGuidance ? 2 : 1;
93+
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
94+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
95+
{
96+
inferenceParameters.AddInputTensor(inputTensor);
97+
inferenceParameters.AddInputTensor(timestepTensor);
98+
inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
99+
100+
// ControlNet
101+
using (var controlNetParameters = new OnnxInferenceParameters(controlNetMetadata))
102+
{
103+
controlNetParameters.AddInputTensor(inputTensor);
104+
controlNetParameters.AddInputTensor(timestepTensor);
105+
controlNetParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
106+
controlNetParameters.AddInputTensor(controlImage);
107+
if (controlNetMetadata.Inputs.Count == 5)
108+
controlNetParameters.AddInputTensor(conditioningScale);
109+
110+
// Optimization: Pre-allocate device buffers for inputs
111+
foreach (var item in controlNetMetadata.Outputs)
112+
controlNetParameters.AddOutputBuffer();
113+
114+
// ControlNet inference
115+
var controlNetResults = _onnxModelService.RunInference(modelOptions.ControlNetModel, OnnxModelType.ControlNet, controlNetParameters);
116+
117+
// Add ControlNet outputs to Unet input
118+
foreach (var item in controlNetResults)
119+
inferenceParameters.AddInput(item);
120+
121+
// Add output buffer
122+
inferenceParameters.AddOutputBuffer(outputDimension);
123+
124+
// Unet inference
125+
var results = await _onnxModelService.RunInferenceAsync(modelOptions.BaseModel, OnnxModelType.Unet, inferenceParameters);
126+
using (var result = results.First())
127+
{
128+
var noisePred = result.ToDenseTensor();
129+
130+
// Perform guidance
131+
if (performGuidance)
132+
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
133+
134+
// Scheduler Step
135+
latents = scheduler.Step(noisePred, timestep, latents).Result;
136+
137+
latents = noisePred
138+
.MultiplyTensorByFloat(distilledTimestep)
139+
.AddTensors(latents);
140+
}
141+
}
142+
}
143+
144+
ReportProgress(progressCallback, step, timesteps.Count, latents);
145+
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
146+
}
147+
148+
// Decode Latents
149+
return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, latents);
150+
}
151+
}
152+
153+
154+
/// <summary>
155+
/// Gets the timesteps.
156+
/// </summary>
157+
/// <param name="options">The options.</param>
158+
/// <param name="scheduler">The scheduler.</param>
159+
/// <returns></returns>
160+
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
161+
{
162+
return scheduler.Timesteps;
163+
}
164+
165+
166+
/// <summary>
167+
/// Prepares the input latents.
168+
/// </summary>
169+
/// <param name="model">The model.</param>
170+
/// <param name="prompt">The prompt.</param>
171+
/// <param name="options">The options.</param>
172+
/// <param name="scheduler">The scheduler.</param>
173+
/// <param name="timesteps">The timesteps.</param>
174+
/// <returns></returns>
175+
protected override Task<DenseTensor<float>> PrepareLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
176+
{
177+
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
178+
}
179+
180+
181+
/// <summary>
182+
/// Creates the Conditioning Scale tensor.
183+
/// </summary>
184+
/// <param name="conditioningScale">The conditioningScale.</param>
185+
/// <returns></returns>
186+
protected static DenseTensor<double> CreateConditioningScaleTensor(float conditioningScale)
187+
{
188+
return TensorHelper.CreateTensor(new double[] { conditioningScale }, new int[] { 1 });
189+
}
190+
191+
192+
/// <summary>
193+
/// Prepares the control image.
194+
/// </summary>
195+
/// <param name="promptOptions">The prompt options.</param>
196+
/// <param name="schedulerOptions">The scheduler options.</param>
197+
/// <returns></returns>
198+
protected async Task<DenseTensor<float>> PrepareControlImage(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions)
199+
{
200+
var controlImage = promptOptions.InputContolImage;
201+
if (schedulerOptions.IsControlImageProcessingEnabled)
202+
{
203+
controlImage = await _controlNetImageService.PrepareInputImage(modelOptions.ControlNetModel, promptOptions.InputContolImage, schedulerOptions.Height, schedulerOptions.Width);
204+
}
205+
return controlImage.ToDenseTensor(new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width }, false);
206+
}
207+
}
208+
}

0 commit comments

Comments
 (0)