diff --git a/OnnxStack.StableDiffusion/Diffusers/InstaFlow/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/InstaFlow/ControlNetDiffuser.cs
new file mode 100644
index 00000000..547f55fd
--- /dev/null
+++ b/OnnxStack.StableDiffusion/Diffusers/InstaFlow/ControlNetDiffuser.cs
@@ -0,0 +1,208 @@
+using Microsoft.Extensions.Logging;
+using Microsoft.ML.OnnxRuntime.Tensors;
+using OnnxStack.Core;
+using OnnxStack.Core.Config;
+using OnnxStack.Core.Model;
+using OnnxStack.Core.Services;
+using OnnxStack.StableDiffusion.Common;
+using OnnxStack.StableDiffusion.Config;
+using OnnxStack.StableDiffusion.Enums;
+using OnnxStack.StableDiffusion.Helpers;
+using OnnxStack.StableDiffusion.Models;
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace OnnxStack.StableDiffusion.Diffusers.InstaFlow
+{
+ public class ControlNetDiffuser : InstaFlowDiffuser
+ {
+ private readonly IControlNetImageService _controlNetImageService;
+
+ /// <summary>Initializes a new instance of the <see cref="ControlNetDiffuser"/> class.</summary>
+ /// <param name="onnxModelService">The onnx model service.</param>
+ /// <param name="promptService">The prompt service.</param>
+ /// <param name="controlNetImageService">The ControlNet image service, stored for control-image preparation.</param>
+ /// <param name="logger">The logger.</param>
+ public ControlNetDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, IControlNetImageService controlNetImageService, ILogger logger)
+ : base(onnxModelService, promptService, logger)
+ {
+ _controlNetImageService = controlNetImageService;
+ }
+
+ /// <summary>
+ /// Gets the type of the diffuser (always <c>DiffuserType.ControlNet</c> for this class).
+ /// </summary>
+ public override DiffuserType DiffuserType => DiffuserType.ControlNet;
+
+
+ /// <summary>
+ /// Runs the denoising loop (ControlNet inference, then Unet inference, per timestep) and decodes the final latents.
+ /// </summary>
+ /// <param name="modelOptions">The model options.</param>
+ /// <param name="promptOptions">The prompt options.</param>
+ /// <param name="schedulerOptions">The scheduler options.</param>
+ /// <param name="promptEmbeddings">The prompt embeddings.</param>
+ /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
+ /// <param name="progressCallback">The progress callback.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ /// <returns>The result of <c>DecodeLatentsAsync</c> for the final latents.</returns>
+ /// <exception cref="OperationCanceledException">Cancellation was requested before a step started.</exception>
+ protected override async Task> SchedulerStepAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default)
+ {
+ // Get Scheduler
+ using (var scheduler = GetScheduler(schedulerOptions))
+ {
+ // Get timesteps
+ var timesteps = GetTimesteps(schedulerOptions, scheduler);
+
+ // Create latent sample
+ var latents = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);
+
+ // Get Unet model metadata
+ var metadata = _onnxModelService.GetModelMetadata(modelOptions.BaseModel, OnnxModelType.Unet);
+
+ // Get ControlNet model metadata
+ var controlNetMetadata = _onnxModelService.GetModelMetadata(modelOptions.ControlNetModel, OnnxModelType.ControlNet);
+
+ // Control Image
+ var controlImage = await PrepareControlImage(modelOptions, promptOptions, schedulerOptions);
+
+ // Get the distilled Timestep
+ var distilledTimestep = 1.0f / timesteps.Count;
+
+ // Loop through the timesteps
+ var step = 0;
+ foreach (var timestep in timesteps)
+ {
+ step++;
+ var stepTime = Stopwatch.GetTimestamp();
+ cancellationToken.ThrowIfCancellationRequested();
+
+ // Create input tensors; under guidance the latents and the control image are doubled.
+ var inputLatent = performGuidance ? latents.Repeat(2) : latents;
+ var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
+ var timestepTensor = CreateTimestepTensor(timestep);
+ var controlImageTensor = performGuidance ? controlImage.Repeat(2) : controlImage;
+ var conditioningScale = CreateConditioningScaleTensor(schedulerOptions.ConditioningScale);
+
+ var outputChannels = performGuidance ? 2 : 1;
+ var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
+ using (var inferenceParameters = new OnnxInferenceParameters(metadata))
+ {
+ inferenceParameters.AddInputTensor(inputTensor);
+ inferenceParameters.AddInputTensor(timestepTensor);
+ inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
+
+ // ControlNet: the control image is passed batched like the other inputs (doubled under guidance)
+ using (var controlNetParameters = new OnnxInferenceParameters(controlNetMetadata))
+ {
+ controlNetParameters.AddInputTensor(inputTensor);
+ controlNetParameters.AddInputTensor(timestepTensor);
+ controlNetParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
+ controlNetParameters.AddInputTensor(controlImageTensor);
+ if (controlNetMetadata.Inputs.Count == 5)
+ controlNetParameters.AddInputTensor(conditioningScale);
+
+ // Optimization: Pre-allocate one output buffer per ControlNet output
+ foreach (var item in controlNetMetadata.Outputs)
+ controlNetParameters.AddOutputBuffer();
+
+ // ControlNet inference
+ var controlNetResults = _onnxModelService.RunInference(modelOptions.ControlNetModel, OnnxModelType.ControlNet, controlNetParameters);
+
+ // Add ControlNet outputs to Unet input
+ foreach (var item in controlNetResults)
+ inferenceParameters.AddInput(item);
+
+ // Add output buffer
+ inferenceParameters.AddOutputBuffer(outputDimension);
+
+ // Unet inference
+ var results = await _onnxModelService.RunInferenceAsync(modelOptions.BaseModel, OnnxModelType.Unet, inferenceParameters);
+ using (var result = results.First())
+ {
+ var noisePred = result.ToDenseTensor();
+
+ // Perform guidance
+ if (performGuidance)
+ noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
+
+ // Scheduler Step
+ latents = scheduler.Step(noisePred, timestep, latents).Result;
+ // NOTE(review): latents is reassigned immediately below; presumably the InstaFlow scheduler step leaves the sample unchanged - confirm.
+ latents = noisePred
+ .MultiplyTensorByFloat(distilledTimestep)
+ .AddTensors(latents);
+ }
+ }
+ }
+
+ ReportProgress(progressCallback, step, timesteps.Count, latents);
+ _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+ }
+
+ // Decode Latents
+ return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, latents);
+ }
+ }
+
+
+ /// <summary>
+ /// Gets the timesteps; this diffuser returns the scheduler's timestep list unchanged.
+ /// </summary>
+ /// <param name="options">The options.</param>
+ /// <param name="scheduler">The scheduler.</param>
+ /// <returns>The value of <c>scheduler.Timesteps</c>.</returns>
+ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, IScheduler scheduler)
+ {
+ return scheduler.Timesteps;
+ }
+
+
+ /// <summary>
+ /// Prepares the input latents as a random sample created by the scheduler.
+ /// </summary>
+ /// <param name="model">The model.</param>
+ /// <param name="prompt">The prompt.</param>
+ /// <param name="options">The options.</param>
+ /// <param name="scheduler">The scheduler.</param>
+ /// <param name="timesteps">The timesteps.</param>
+ /// <returns>A completed task holding the sample returned by <c>scheduler.CreateRandomSample</c>.</returns>
+ protected override Task> PrepareLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps)
+ {
+ return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
+ }
+
+
+ /// <summary>
+ /// Creates the Conditioning Scale tensor.
+ /// </summary>
+ /// <param name="conditioningScale">The conditioningScale.</param>
+ /// <returns>A tensor with dimensions <c>{ 1 }</c> holding the scale widened to <c>double</c>.</returns>
+ protected static DenseTensor CreateConditioningScaleTensor(float conditioningScale)
+ {
+ return TensorHelper.CreateTensor(new double[] { conditioningScale }, new int[] { 1 });
+ }
+
+
+ /// <summary>Prepares the control image tensor, running the ControlNet image
+ /// service first when control-image processing is enabled.</summary>
+ /// <param name="modelOptions">The model options.</param>
+ /// <param name="promptOptions">The prompt options.</param>
+ /// <param name="schedulerOptions">The scheduler options.</param>
+ /// <returns>The control image as a tensor with dimensions { 1, 3, Height, Width }.</returns>
+ protected async Task> PrepareControlImage(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions)
+ {
+ // Use the supplied image as-is unless pre-processing is enabled.
+ var image = schedulerOptions.IsControlImageProcessingEnabled
+ ? await _controlNetImageService.PrepareInputImage(modelOptions.ControlNetModel, promptOptions.InputContolImage, schedulerOptions.Height, schedulerOptions.Width)
+ : promptOptions.InputContolImage;
+ var dimensions = new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width };
+ return image.ToDenseTensor(dimensions, false);
+ }
+ }
+}
diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetDiffuser.cs
new file mode 100644
index 00000000..c9cab1a0
--- /dev/null
+++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetDiffuser.cs
@@ -0,0 +1,206 @@
+using Microsoft.Extensions.Logging;
+using Microsoft.ML.OnnxRuntime.Tensors;
+using OnnxStack.Core;
+using OnnxStack.Core.Config;
+using OnnxStack.Core.Model;
+using OnnxStack.Core.Services;
+using OnnxStack.StableDiffusion.Common;
+using OnnxStack.StableDiffusion.Config;
+using OnnxStack.StableDiffusion.Enums;
+using OnnxStack.StableDiffusion.Helpers;
+using OnnxStack.StableDiffusion.Models;
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency
+{
+ public class ControlNetDiffuser : LatentConsistencyDiffuser
+ {
+ private readonly IControlNetImageService _controlNetImageService;
+
+ /// <summary>Initializes a new instance of the <see cref="ControlNetDiffuser"/> class.</summary>
+ /// <param name="onnxModelService">The onnx model service.</param>
+ /// <param name="promptService">The prompt service.</param>
+ /// <param name="controlNetImageService">The ControlNet image service, stored for control-image preparation.</param>
+ /// <param name="logger">The logger.</param>
+ public ControlNetDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, IControlNetImageService controlNetImageService, ILogger logger)
+ : base(onnxModelService, promptService, logger)
+ {
+ _controlNetImageService = controlNetImageService;
+ }
+
+ /// <summary>
+ /// Gets the type of the diffuser (always <c>DiffuserType.ControlNet</c> for this class).
+ /// </summary>
+ public override DiffuserType DiffuserType => DiffuserType.ControlNet;
+
+
+ /// <summary>
+ /// Runs the denoising loop (ControlNet inference, then Unet inference, per timestep) and decodes the denoised sample.
+ /// </summary>
+ /// <param name="modelOptions">The model options.</param>
+ /// <param name="promptOptions">The prompt options.</param>
+ /// <param name="schedulerOptions">The scheduler options.</param>
+ /// <param name="promptEmbeddings">The prompt embeddings.</param>
+ /// <param name="performGuidance">if set to <c>true</c> [perform guidance]; not read by this implementation.</param>
+ /// <param name="progressCallback">The progress callback.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ /// <returns>The result of <c>DecodeLatentsAsync</c> for the last denoised sample (or the initial latents when no step ran).</returns>
+ /// <exception cref="OperationCanceledException">Cancellation was requested before a step started.</exception>
+ protected override async Task> SchedulerStepAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default)
+ {
+ // Get Scheduler
+ using (var scheduler = GetScheduler(schedulerOptions))
+ {
+ // Get timesteps
+ var timesteps = GetTimesteps(schedulerOptions, scheduler);
+
+ // Create latent sample
+ var latents = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);
+
+ // Get Guidance Scale Embedding
+ var guidanceEmbeddings = GetGuidanceScaleEmbedding(schedulerOptions.GuidanceScale);
+
+ // Get Unet model metadata
+ var metadata = _onnxModelService.GetModelMetadata(modelOptions.BaseModel, OnnxModelType.Unet);
+
+ // Get ControlNet model metadata
+ var controlNetMetadata = _onnxModelService.GetModelMetadata(modelOptions.ControlNetModel, OnnxModelType.ControlNet);
+
+ // Control Image
+ var controlImageTensor = await PrepareControlImage(modelOptions, promptOptions, schedulerOptions);
+
+ // Denoised result of the most recent scheduler step; this is what gets decoded at the end
+ DenseTensor denoised = null;
+
+ // Loop through the timesteps
+ var step = 0;
+ foreach (var timestep in timesteps)
+ {
+ step++;
+ var stepTime = Stopwatch.GetTimestamp();
+ cancellationToken.ThrowIfCancellationRequested();
+
+ // Create input tensor.
+ var inputTensor = scheduler.ScaleInput(latents, timestep);
+ var timestepTensor = CreateTimestepTensor(timestep);
+ var conditioningScale = CreateConditioningScaleTensor(schedulerOptions.ConditioningScale);
+
+ var batchCount = 1;
+ var outputDimension = schedulerOptions.GetScaledDimension(batchCount);
+ using (var inferenceParameters = new OnnxInferenceParameters(metadata))
+ {
+ inferenceParameters.AddInputTensor(inputTensor);
+ inferenceParameters.AddInputTensor(timestepTensor);
+ inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
+ inferenceParameters.AddInputTensor(guidanceEmbeddings);
+
+ // ControlNet
+ using (var controlNetParameters = new OnnxInferenceParameters(controlNetMetadata))
+ {
+ controlNetParameters.AddInputTensor(inputTensor);
+ controlNetParameters.AddInputTensor(timestepTensor);
+ controlNetParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
+ controlNetParameters.AddInputTensor(guidanceEmbeddings);
+ controlNetParameters.AddInputTensor(controlImageTensor);
+ if (controlNetMetadata.Inputs.Count == 5) // NOTE(review): five inputs are already added above - confirm the expected count
+ controlNetParameters.AddInputTensor(conditioningScale);
+
+ // Optimization: Pre-allocate one output buffer per ControlNet output
+ foreach (var item in controlNetMetadata.Outputs)
+ controlNetParameters.AddOutputBuffer();
+
+ // ControlNet inference
+ var controlNetResults = _onnxModelService.RunInference(modelOptions.ControlNetModel, OnnxModelType.ControlNet, controlNetParameters);
+
+ // Add ControlNet outputs to Unet input
+ foreach (var item in controlNetResults)
+ inferenceParameters.AddInput(item);
+
+ // Add output buffer
+ inferenceParameters.AddOutputBuffer(outputDimension);
+
+ // Unet inference
+ var results = await _onnxModelService.RunInferenceAsync(modelOptions.BaseModel, OnnxModelType.Unet, inferenceParameters);
+ using (var result = results.First())
+ {
+ var noisePred = result.ToDenseTensor();
+
+ // Scheduler Step
+ var schedulerResult = scheduler.Step(noisePred, timestep, latents);
+
+ latents = schedulerResult.Result;
+ denoised = schedulerResult.SampleData;
+ }
+ }
+ }
+
+ ReportProgress(progressCallback, step, timesteps.Count, latents);
+ _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+ }
+
+ // Decode the denoised sample rather than the latents; fall back to latents when no step ran
+ return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, denoised ?? latents);
+ }
+ }
+
+
+ /// <summary>
+ /// Gets the timesteps; this diffuser returns the scheduler's timestep list unchanged.
+ /// </summary>
+ /// <param name="options">The options.</param>
+ /// <param name="scheduler">The scheduler.</param>
+ /// <returns>The value of <c>scheduler.Timesteps</c>.</returns>
+ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, IScheduler scheduler)
+ {
+ return scheduler.Timesteps;
+ }
+
+
+ /// <summary>
+ /// Prepares the input latents as a random sample created by the scheduler.
+ /// </summary>
+ /// <param name="model">The model.</param>
+ /// <param name="prompt">The prompt.</param>
+ /// <param name="options">The options.</param>
+ /// <param name="scheduler">The scheduler.</param>
+ /// <param name="timesteps">The timesteps.</param>
+ /// <returns>A completed task holding the sample returned by <c>scheduler.CreateRandomSample</c>.</returns>
+ protected override Task> PrepareLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps)
+ {
+ return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
+ }
+
+
+ /// <summary>
+ /// Creates the Conditioning Scale tensor.
+ /// </summary>
+ /// <param name="conditioningScale">The conditioningScale.</param>
+ /// <returns>A tensor with dimensions <c>{ 1 }</c> holding the scale widened to <c>double</c>.</returns>
+ protected static DenseTensor CreateConditioningScaleTensor(float conditioningScale)
+ {
+ return TensorHelper.CreateTensor(new double[] { conditioningScale }, new int[] { 1 });
+ }
+
+
+ /// <summary>Prepares the control image tensor, running the ControlNet image
+ /// service first when control-image processing is enabled.</summary>
+ /// <param name="modelOptions">The model options.</param>
+ /// <param name="promptOptions">The prompt options.</param>
+ /// <param name="schedulerOptions">The scheduler options.</param>
+ /// <returns>The control image as a tensor with dimensions { 1, 3, Height, Width }.</returns>
+ protected async Task> PrepareControlImage(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions)
+ {
+ // Use the supplied image as-is unless pre-processing is enabled.
+ var image = schedulerOptions.IsControlImageProcessingEnabled
+ ? await _controlNetImageService.PrepareInputImage(modelOptions.ControlNetModel, promptOptions.InputContolImage, schedulerOptions.Height, schedulerOptions.Width)
+ : promptOptions.InputContolImage;
+ var dimensions = new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width };
+ return image.ToDenseTensor(dimensions, false);
+ }
+ }
+}
diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetImageDiffuser.cs
new file mode 100644
index 00000000..e535989d
--- /dev/null
+++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetImageDiffuser.cs
@@ -0,0 +1,83 @@
+using Microsoft.Extensions.Logging;
+using Microsoft.ML.OnnxRuntime.Tensors;
+using OnnxStack.Core;
+using OnnxStack.Core.Config;
+using OnnxStack.Core.Model;
+using OnnxStack.Core.Services;
+using OnnxStack.StableDiffusion.Common;
+using OnnxStack.StableDiffusion.Config;
+using OnnxStack.StableDiffusion.Enums;
+using OnnxStack.StableDiffusion.Helpers;
+using SixLabors.ImageSharp;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+
+namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency
+{
+ public sealed class ControlNetImageDiffuser : ControlNetDiffuser
+ {
+ /// <summary>Initializes a new instance of the <see cref="ControlNetImageDiffuser"/> class.</summary>
+ /// <param name="onnxModelService">The onnx model service.</param>
+ /// <param name="promptService">The prompt service.</param>
+ /// <param name="controlNetImageService">The ControlNet image service.</param>
+ /// <param name="logger">The logger.</param>
+ public ControlNetImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, IControlNetImageService controlNetImageService, ILogger logger)
+ : base(onnxModelService, promptService, controlNetImageService, logger)
+ {
+ }
+
+
+ /// <summary>
+ /// Gets the type of the diffuser (always <c>DiffuserType.ControlNetImage</c> for this class).
+ /// </summary>
+ public override DiffuserType DiffuserType => DiffuserType.ControlNetImage;
+
+
+ /// <summary>
+ /// Gets the timesteps for image-to-image: leading scheduler timesteps are skipped
+ /// so that at most (int)(InferenceSteps * Strength) entries remain.
+ /// </summary>
+ /// <param name="options">The options.</param>
+ /// <param name="scheduler">The scheduler.</param>
+ /// <returns>The remaining timesteps, materialised as a list.</returns>
+ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, IScheduler scheduler)
+ {
+ var totalSteps = options.InferenceSteps;
+ var stepsToRun = Math.Min((int)(totalSteps * options.Strength), totalSteps);
+ return scheduler.Timesteps.Skip(Math.Max(totalSteps - stepsToRun, 0)).ToList();
+ }
+
+
+ /// <summary>Prepares the latents by VAE-encoding the input image, scaling the result and adding scheduler noise.</summary>
+ /// <param name="model">The model.</param>
+ /// <param name="prompt">The prompt.</param>
+ /// <param name="options">The options.</param>
+ /// <param name="scheduler">The scheduler.</param>
+ /// <param name="timesteps">The timesteps.</param>
+ /// <returns>The scaled encoder output after <c>scheduler.AddNoise</c>.</returns>
+ protected override async Task> PrepareLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps)
+ {
+ var pixels = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
+
+ //TODO: Model Config, Channels
+ var latentShape = options.GetScaledDimension();
+ var encoderMetadata = _onnxModelService.GetModelMetadata(model.BaseModel, OnnxModelType.VaeEncoder);
+ using (var encoderParameters = new OnnxInferenceParameters(encoderMetadata))
+ {
+ encoderParameters.AddInputTensor(pixels);
+ encoderParameters.AddOutputBuffer(latentShape);
+
+ var encoderOutputs = await _onnxModelService.RunInferenceAsync(model.BaseModel, OnnxModelType.VaeEncoder, encoderParameters);
+ using (var encoded = encoderOutputs.First())
+ {
+ var scaledLatents = encoded.ToDenseTensor().MultiplyBy(model.ScaleFactor);
+ var noise = scheduler.CreateRandomSample(scaledLatents.Dimensions);
+ return scheduler.AddNoise(scaledLatents, noise, timesteps);
+ }
+ }
+ }
+
+ }
+}
diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetDiffuser.cs
new file mode 100644
index 00000000..e1f265f3
--- /dev/null
+++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetDiffuser.cs
@@ -0,0 +1,211 @@
+using Microsoft.Extensions.Logging;
+using Microsoft.ML.OnnxRuntime.Tensors;
+using OnnxStack.Core;
+using OnnxStack.Core.Config;
+using OnnxStack.Core.Model;
+using OnnxStack.Core.Services;
+using OnnxStack.StableDiffusion.Common;
+using OnnxStack.StableDiffusion.Config;
+using OnnxStack.StableDiffusion.Enums;
+using OnnxStack.StableDiffusion.Helpers;
+using OnnxStack.StableDiffusion.Models;
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistencyXL
+{
+ public class ControlNetDiffuser : LatentConsistencyXLDiffuser
+ {
+ private readonly IControlNetImageService _controlNetImageService;
+
+ /// <summary>Initializes a new instance of the <see cref="ControlNetDiffuser"/> class.</summary>
+ /// <param name="onnxModelService">The onnx model service.</param>
+ /// <param name="promptService">The prompt service.</param>
+ /// <param name="controlNetImageService">The ControlNet image service, stored for control-image preparation.</param>
+ /// <param name="logger">The logger.</param>
+ public ControlNetDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, IControlNetImageService controlNetImageService, ILogger logger)
+ : base(onnxModelService, promptService, logger)
+ {
+ _controlNetImageService = controlNetImageService;
+ }
+
+
+ /// <summary>
+ /// Gets the type of the diffuser (always <c>DiffuserType.ControlNet</c> for this class).
+ /// </summary>
+ public override DiffuserType DiffuserType => DiffuserType.ControlNet;
+
+
+ /// <summary>
+ /// Runs the denoising loop (ControlNet inference, then Unet inference, per timestep) and decodes the final latents.
+ /// </summary>
+ /// <param name="modelOptions">The model options.</param>
+ /// <param name="promptOptions">The prompt options.</param>
+ /// <param name="schedulerOptions">The scheduler options.</param>
+ /// <param name="promptEmbeddings">The prompt embeddings.</param>
+ /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
+ /// <param name="progressCallback">The progress callback.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ /// <returns>The result of <c>DecodeLatentsAsync</c> for the final latents.</returns>
+ /// <exception cref="OperationCanceledException">Cancellation was requested before a step started.</exception>
+ protected override async Task> SchedulerStepAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default)
+ {
+ // Get Scheduler
+ using (var scheduler = GetScheduler(schedulerOptions))
+ {
+ // Get timesteps
+ var timesteps = GetTimesteps(schedulerOptions, scheduler);
+
+ // Create latent sample
+ var latents = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);
+
+ // Get Unet model metadata
+ var metadata = _onnxModelService.GetModelMetadata(modelOptions.BaseModel, OnnxModelType.Unet);
+
+ // Get Time ids
+ var addTimeIds = GetAddTimeIds(modelOptions, schedulerOptions);
+
+ // Get ControlNet model metadata
+ var controlNetMetadata = _onnxModelService.GetModelMetadata(modelOptions.ControlNetModel, OnnxModelType.ControlNet);
+
+ // Control Image
+ var controlImage = await PrepareControlImage(modelOptions, promptOptions, schedulerOptions);
+
+ // Loop through the timesteps
+ var step = 0;
+ foreach (var timestep in timesteps)
+ {
+ step++;
+ var stepTime = Stopwatch.GetTimestamp();
+ cancellationToken.ThrowIfCancellationRequested();
+
+ // Create input tensors; under guidance the latents, time ids and control image are doubled.
+ var inputLatent = performGuidance ? latents.Repeat(2) : latents;
+ var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
+ var timestepTensor = CreateTimestepTensor(timestep);
+ var timeids = performGuidance ? addTimeIds.Repeat(2) : addTimeIds;
+ var controlImageTensor = performGuidance ? controlImage.Repeat(2) : controlImage;
+ var conditioningScale = CreateConditioningScaleTensor(schedulerOptions.ConditioningScale);
+
+ var batchCount = performGuidance ? 2 : 1;
+ var outputDimension = schedulerOptions.GetScaledDimension(batchCount);
+ using (var inferenceParameters = new OnnxInferenceParameters(metadata))
+ {
+ inferenceParameters.AddInputTensor(inputTensor);
+ inferenceParameters.AddInputTensor(timestepTensor);
+ inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
+ inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
+ inferenceParameters.AddInputTensor(timeids);
+ inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds); // NOTE(review): PromptEmbeds is added a second time here - confirm the Unet expects this input
+
+ // ControlNet: the control image is passed batched like the other inputs (doubled under guidance)
+ using (var controlNetParameters = new OnnxInferenceParameters(controlNetMetadata))
+ {
+ controlNetParameters.AddInputTensor(inputTensor);
+ controlNetParameters.AddInputTensor(timestepTensor);
+ controlNetParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
+ controlNetParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
+ controlNetParameters.AddInputTensor(timeids);
+ controlNetParameters.AddInputTensor(controlImageTensor);
+ if (controlNetMetadata.Inputs.Count == 5) // NOTE(review): six inputs are already added above - confirm the expected count
+ controlNetParameters.AddInputTensor(conditioningScale);
+
+ // Optimization: Pre-allocate one output buffer per ControlNet output
+ foreach (var item in controlNetMetadata.Outputs)
+ controlNetParameters.AddOutputBuffer();
+
+ // ControlNet inference
+ var controlNetResults = _onnxModelService.RunInference(modelOptions.ControlNetModel, OnnxModelType.ControlNet, controlNetParameters);
+
+ // Add ControlNet outputs to Unet input
+ foreach (var item in controlNetResults)
+ inferenceParameters.AddInput(item);
+
+ // Add output buffer
+ inferenceParameters.AddOutputBuffer(outputDimension);
+
+ // Unet inference
+ var results = await _onnxModelService.RunInferenceAsync(modelOptions.BaseModel, OnnxModelType.Unet, inferenceParameters);
+ using (var result = results.First())
+ {
+ var noisePred = result.ToDenseTensor();
+
+ // Perform guidance
+ if (performGuidance)
+ noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
+
+ // Scheduler Step
+ latents = scheduler.Step(noisePred, timestep, latents).Result;
+ }
+ }
+ }
+
+ ReportProgress(progressCallback, step, timesteps.Count, latents);
+ _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+ }
+
+ // Decode Latents
+ return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, latents);
+ }
+ }
+
+
+ /// <summary>
+ /// Gets the timesteps; this diffuser returns the scheduler's timestep list unchanged.
+ /// </summary>
+ /// <param name="options">The options.</param>
+ /// <param name="scheduler">The scheduler.</param>
+ /// <returns>The value of <c>scheduler.Timesteps</c>.</returns>
+ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, IScheduler scheduler)
+ {
+ return scheduler.Timesteps;
+ }
+
+
+ /// <summary>
+ /// Prepares the input latents as a random sample created by the scheduler.
+ /// </summary>
+ /// <param name="model">The model.</param>
+ /// <param name="prompt">The prompt.</param>
+ /// <param name="options">The options.</param>
+ /// <param name="scheduler">The scheduler.</param>
+ /// <param name="timesteps">The timesteps.</param>
+ /// <returns>A completed task holding the sample returned by <c>scheduler.CreateRandomSample</c>.</returns>
+ protected override Task> PrepareLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps)
+ {
+ return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
+ }
+
+
+ /// <summary>
+ /// Creates the Conditioning Scale tensor.
+ /// </summary>
+ /// <param name="conditioningScale">The conditioningScale.</param>
+ /// <returns>A tensor with dimensions <c>{ 1 }</c> holding the scale widened to <c>double</c>.</returns>
+ protected static DenseTensor CreateConditioningScaleTensor(float conditioningScale)
+ {
+ return TensorHelper.CreateTensor(new double[] { conditioningScale }, new int[] { 1 });
+ }
+
+
+ /// <summary>Prepares the control image tensor, running the ControlNet image
+ /// service first when control-image processing is enabled.</summary>
+ /// <param name="modelOptions">The model options.</param>
+ /// <param name="promptOptions">The prompt options.</param>
+ /// <param name="schedulerOptions">The scheduler options.</param>
+ /// <returns>The control image as a tensor with dimensions { 1, 3, Height, Width }.</returns>
+ protected async Task> PrepareControlImage(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions)
+ {
+ // Use the supplied image as-is unless pre-processing is enabled.
+ var image = schedulerOptions.IsControlImageProcessingEnabled
+ ? await _controlNetImageService.PrepareInputImage(modelOptions.ControlNetModel, promptOptions.InputContolImage, schedulerOptions.Height, schedulerOptions.Width)
+ : promptOptions.InputContolImage;
+ var dimensions = new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width };
+ return image.ToDenseTensor(dimensions, false);
+ }
+ }
+}
diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetImageDiffuser.cs
new file mode 100644
index 00000000..3dd18a34
--- /dev/null
+++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetImageDiffuser.cs
@@ -0,0 +1,83 @@
+using Microsoft.Extensions.Logging;
+using Microsoft.ML.OnnxRuntime.Tensors;
+using OnnxStack.Core;
+using OnnxStack.Core.Config;
+using OnnxStack.Core.Model;
+using OnnxStack.Core.Services;
+using OnnxStack.StableDiffusion.Common;
+using OnnxStack.StableDiffusion.Config;
+using OnnxStack.StableDiffusion.Enums;
+using OnnxStack.StableDiffusion.Helpers;
+using SixLabors.ImageSharp;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+
+namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistencyXL
+{
+ public sealed class ControlNetImageDiffuser : ControlNetDiffuser
+ {
+ /// <summary>Initializes a new instance of the <see cref="ControlNetImageDiffuser"/> class.</summary>
+ /// <param name="onnxModelService">The onnx model service.</param>
+ /// <param name="promptService">The prompt service.</param>
+ /// <param name="controlNetImageService">The ControlNet image service.</param>
+ /// <param name="logger">The logger.</param>
+ public ControlNetImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, IControlNetImageService controlNetImageService, ILogger logger)
+ : base(onnxModelService, promptService, controlNetImageService, logger)
+ {
+ }
+
+
+ /// <summary>
+ /// Gets the type of the diffuser (always <c>DiffuserType.ControlNetImage</c> for this class).
+ /// </summary>
+ public override DiffuserType DiffuserType => DiffuserType.ControlNetImage;
+
+
+ /// <summary>
+ /// Gets the timesteps for image-to-image: leading scheduler timesteps are skipped
+ /// so that at most (int)(InferenceSteps * Strength) entries remain.
+ /// </summary>
+ /// <param name="options">The options.</param>
+ /// <param name="scheduler">The scheduler.</param>
+ /// <returns>The remaining timesteps, materialised as a list.</returns>
+ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, IScheduler scheduler)
+ {
+ var totalSteps = options.InferenceSteps;
+ var stepsToRun = Math.Min((int)(totalSteps * options.Strength), totalSteps);
+ return scheduler.Timesteps.Skip(Math.Max(totalSteps - stepsToRun, 0)).ToList();
+ }
+
+
+ /// <summary>Prepares the latents by VAE-encoding the input image, scaling the result and adding scheduler noise.</summary>
+ /// <param name="model">The model.</param>
+ /// <param name="prompt">The prompt.</param>
+ /// <param name="options">The options.</param>
+ /// <param name="scheduler">The scheduler.</param>
+ /// <param name="timesteps">The timesteps.</param>
+ /// <returns>The scaled encoder output after <c>scheduler.AddNoise</c>.</returns>
+ protected override async Task> PrepareLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps)
+ {
+ var pixels = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
+
+ //TODO: Model Config, Channels
+ var latentShape = options.GetScaledDimension();
+ var encoderMetadata = _onnxModelService.GetModelMetadata(model.BaseModel, OnnxModelType.VaeEncoder);
+ using (var encoderParameters = new OnnxInferenceParameters(encoderMetadata))
+ {
+ encoderParameters.AddInputTensor(pixels);
+ encoderParameters.AddOutputBuffer(latentShape);
+
+ var encoderOutputs = await _onnxModelService.RunInferenceAsync(model.BaseModel, OnnxModelType.VaeEncoder, encoderParameters);
+ using (var encoded = encoderOutputs.First())
+ {
+ var scaledLatents = encoded.ToDenseTensor().MultiplyBy(model.ScaleFactor);
+ var noise = scheduler.CreateRandomSample(scaledLatents.Dimensions);
+ return scheduler.AddNoise(scaledLatents, noise, timesteps);
+ }
+ }
+ }
+
+ }
+}
diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetDiffuser.cs
index 189d0b83..4b97b7e2 100644
--- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetDiffuser.cs
@@ -9,7 +9,6 @@
using OnnxStack.StableDiffusion.Enums;
using OnnxStack.StableDiffusion.Helpers;
using OnnxStack.StableDiffusion.Models;
-using OnnxStack.StableDiffusion.Schedulers.StableDiffusion;
using System;
using System.Collections.Generic;
using System.Diagnostics;
@@ -19,7 +18,7 @@
namespace OnnxStack.StableDiffusion.Diffusers.StableDiffusion
{
- public class ControlNetDiffuser : DiffuserBase
+ public class ControlNetDiffuser : StableDiffusionDiffuser
{
private readonly IControlNetImageService _controlNetImageService;
@@ -35,12 +34,6 @@ public ControlNetDiffuser(IOnnxModelService onnxModelService, IPromptService pro
}
- ///
- /// Gets the type of the pipeline.
- ///
- public override DiffuserPipelineType PipelineType => DiffuserPipelineType.StableDiffusion;
-
-
///
/// Gets the type of the diffuser.
///
@@ -205,26 +198,5 @@ protected async Task> PrepareControlImage(ModelOptions modelO
}
return controlImage.ToDenseTensor(new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width }, false);
}
-
-
- ///
- /// Gets the scheduler.
- ///
- /// The options.
- /// The scheduler configuration.
- ///
- protected override IScheduler GetScheduler(SchedulerOptions options)
- {
- return options.SchedulerType switch
- {
- SchedulerType.LMS => new LMSScheduler(options),
- SchedulerType.Euler => new EulerScheduler(options),
- SchedulerType.EulerAncestral => new EulerAncestralScheduler(options),
- SchedulerType.DDPM => new DDPMScheduler(options),
- SchedulerType.DDIM => new DDIMScheduler(options),
- SchedulerType.KDPM2 => new KDPM2Scheduler(options),
- _ => default
- };
- }
}
}
diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetDiffuser.cs
new file mode 100644
index 00000000..d0d921b1
--- /dev/null
+++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetDiffuser.cs
@@ -0,0 +1,211 @@
+using Microsoft.Extensions.Logging;
+using Microsoft.ML.OnnxRuntime.Tensors;
+using OnnxStack.Core;
+using OnnxStack.Core.Config;
+using OnnxStack.Core.Model;
+using OnnxStack.Core.Services;
+using OnnxStack.StableDiffusion.Common;
+using OnnxStack.StableDiffusion.Config;
+using OnnxStack.StableDiffusion.Enums;
+using OnnxStack.StableDiffusion.Helpers;
+using OnnxStack.StableDiffusion.Models;
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace OnnxStack.StableDiffusion.Diffusers.StableDiffusionXL
+{
+ public class ControlNetDiffuser : StableDiffusionXLDiffuser
+ {
+ private readonly IControlNetImageService _controlNetImageService;
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="ControlNetDiffuser"/> class.
+        /// </summary>
+        /// <param name="onnxModelService">The onnx model service.</param>
+        /// <param name="promptService">The prompt service.</param>
+        /// <param name="controlNetImageService">The service used to prepare control images.</param>
+        /// <param name="logger">The logger.</param>
+        public ControlNetDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, IControlNetImageService controlNetImageService, ILogger logger)
+            : base(onnxModelService, promptService, logger)
+            => _controlNetImageService = controlNetImageService;
+
+
+        /// <summary>
+        /// Gets the type of the diffuser, which is <c>DiffuserType.ControlNet</c> for this class.
+        /// </summary>
+        public override DiffuserType DiffuserType { get { return DiffuserType.ControlNet; } }
+
+
+        /// <summary>
+        /// Runs the denoising loop: on every timestep the ControlNet is evaluated and its outputs are passed to the Unet as additional inputs.
+        /// </summary>
+        /// <param name="modelOptions">The model options; supplies the base model and the ControlNet model.</param>
+        /// <param name="promptOptions">The prompt options; supplies the control image.</param>
+        /// <param name="schedulerOptions">The scheduler options (dimensions, guidance scale, conditioning scale).</param>
+        /// <param name="promptEmbeddings">The prompt embeddings and pooled prompt embeddings.</param>
+        /// <param name="performGuidance">if set to <c>true</c> the model inputs are repeated twice and guidance is applied to the prediction.</param>
+        /// <param name="progressCallback">The progress callback, reported after every step.</param>
+        /// <param name="cancellationToken">The cancellation token, checked at the start of every step.</param>
+        /// <returns>The result of decoding the final latents.</returns>
+        /// <exception cref="OperationCanceledException">Thrown when <paramref name="cancellationToken"/> is cancelled.</exception>
+        protected override async Task> SchedulerStepAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default)
+        {
+            // Get Scheduler
+            using (var scheduler = GetScheduler(schedulerOptions))
+            {
+                // Get timesteps
+                var timesteps = GetTimesteps(schedulerOptions, scheduler);
+
+                // Create latent sample
+                var latents = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);
+
+                // Get Unet model metadata
+                var metadata = _onnxModelService.GetModelMetadata(modelOptions.BaseModel, OnnxModelType.Unet);
+
+                // Get Time ids
+                var addTimeIds = GetAddTimeIds(modelOptions, schedulerOptions);
+
+                // Get ControlNet model metadata
+                var controlNetMetadata = _onnxModelService.GetModelMetadata(modelOptions.ControlNetModel, OnnxModelType.ControlNet);
+
+                // Control Image
+                var controlImage = await PrepareControlImage(modelOptions, promptOptions, schedulerOptions);
+
+                // Loop through the timesteps
+                var step = 0;
+                foreach (var timestep in timesteps)
+                {
+                    step++;
+                    var stepTime = Stopwatch.GetTimestamp();
+                    cancellationToken.ThrowIfCancellationRequested();
+
+                    // Create input tensor.
+                    var inputLatent = performGuidance ? latents.Repeat(2) : latents;
+                    var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
+                    var timestepTensor = CreateTimestepTensor(timestep);
+                    var timeids = performGuidance ? addTimeIds.Repeat(2) : addTimeIds;
+                    var controlImageTensor = performGuidance ? controlImage.Repeat(2) : controlImage;
+                    var conditioningScale = CreateConditioningScaleTensor(schedulerOptions.ConditioningScale);
+
+                    var batchCount = performGuidance ? 2 : 1;
+                    var outputDimension = schedulerOptions.GetScaledDimension(batchCount);
+                    using (var inferenceParameters = new OnnxInferenceParameters(metadata))
+                    {
+                        inferenceParameters.AddInputTensor(inputTensor);
+                        inferenceParameters.AddInputTensor(timestepTensor);
+                        inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
+                        inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
+                        inferenceParameters.AddInputTensor(timeids);
+                        inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds); // NOTE(review): PromptEmbeds is added a second time here - confirm the Unet expects this input
+
+                        // ControlNet
+                        using (var controlNetParameters = new OnnxInferenceParameters(controlNetMetadata))
+                        {
+                            controlNetParameters.AddInputTensor(inputTensor);
+                            controlNetParameters.AddInputTensor(timestepTensor);
+                            controlNetParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
+                            controlNetParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
+                            controlNetParameters.AddInputTensor(timeids);
+                            controlNetParameters.AddInputTensor(controlImageTensor); // repeated like inputTensor/timeids when guidance is on
+                            if (controlNetMetadata.Inputs.Count == 5) // NOTE(review): six tensors are already added above - confirm this count for SDXL ControlNet models
+                                controlNetParameters.AddInputTensor(conditioningScale);
+
+                            // Optimization: Pre-allocate device buffers for outputs
+                            foreach (var item in controlNetMetadata.Outputs)
+                                controlNetParameters.AddOutputBuffer();
+
+                            // ControlNet inference
+                            var controlNetResults = _onnxModelService.RunInference(modelOptions.ControlNetModel, OnnxModelType.ControlNet, controlNetParameters);
+
+                            // Add ControlNet outputs to Unet input
+                            foreach (var item in controlNetResults)
+                                inferenceParameters.AddInput(item);
+
+                            // Add output buffer
+                            inferenceParameters.AddOutputBuffer(outputDimension);
+
+                            // Unet inference
+                            var results = await _onnxModelService.RunInferenceAsync(modelOptions.BaseModel, OnnxModelType.Unet, inferenceParameters);
+                            using (var result = results.First())
+                            {
+                                var noisePred = result.ToDenseTensor();
+
+                                // Perform guidance
+                                if (performGuidance)
+                                    noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
+
+                                // Scheduler Step
+                                latents = scheduler.Step(noisePred, timestep, latents).Result;
+                            }
+                        }
+                    }
+
+                    ReportProgress(progressCallback, step, timesteps.Count, latents);
+                    _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+                }
+
+                // Decode Latents
+                return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, latents);
+            }
+        }
+
+
+        /// <summary>
+        /// Gets the timesteps to run; this diffuser uses the scheduler's timesteps unchanged.
+        /// </summary>
+        /// <param name="options">The scheduler options; not read by this implementation.</param>
+        /// <param name="scheduler">The scheduler that supplies the timesteps.</param>
+        /// <returns>
+        /// The value of <c>scheduler.Timesteps</c>.
+        /// </returns>
+        protected override IReadOnlyList GetTimesteps(SchedulerOptions options, IScheduler scheduler)
+            => scheduler.Timesteps;
+
+
+        /// <summary>Creates the initial latents: a random sample from the scheduler, created with its <c>InitNoiseSigma</c>.</summary>
+        /// <param name="model">The model options; not read by this implementation.</param>
+        /// <param name="prompt">The prompt options; not read by this implementation.</param>
+        /// <param name="options">The scheduler options; supplies the scaled dimension of the sample.</param>
+        /// <param name="scheduler">The scheduler that creates the random sample.</param>
+        /// <param name="timesteps">The timesteps; not read by this implementation.</param>
+        /// <returns>An already-completed task holding the random sample.</returns>
+        protected override Task> PrepareLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps)
+        {
+            var sampleDimension = options.GetScaledDimension();
+            var randomSample = scheduler.CreateRandomSample(sampleDimension, scheduler.InitNoiseSigma);
+            return Task.FromResult(randomSample);
+        }
+
+
+        /// <summary>Creates the Conditioning Scale tensor: a single value, widened to <c>double</c>, with dimensions <c>{ 1 }</c>.</summary>
+        /// <param name="conditioningScale">The conditioning scale to store in the tensor.</param>
+        /// <returns>The tensor produced by <c>TensorHelper.CreateTensor</c>.</returns>
+        protected static DenseTensor CreateConditioningScaleTensor(float conditioningScale)
+        {
+            var scaleValues = new double[] { conditioningScale };
+            var scaleDimensions = new int[] { 1 };
+            return TensorHelper.CreateTensor(scaleValues, scaleDimensions);
+        }
+
+
+        /// <summary>
+        /// Prepares the control image tensor, running the ControlNet image service first when control image processing is enabled.
+        /// </summary>
+        /// <param name="modelOptions">The model options; supplies the ControlNet model passed to the image service.</param>
+        /// <param name="promptOptions">The prompt options; supplies the input control image.</param>
+        /// <param name="schedulerOptions">The scheduler options; supplies the height, width and the processing flag.</param>
+        /// <returns>The control image converted with dimensions <c>{ 1, 3, Height, Width }</c>.</returns>
+        protected async Task> PrepareControlImage(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions)
+        {
+            var sourceImage = schedulerOptions.IsControlImageProcessingEnabled
+                ? await _controlNetImageService.PrepareInputImage(modelOptions.ControlNetModel, promptOptions.InputContolImage, schedulerOptions.Height, schedulerOptions.Width)
+                : promptOptions.InputContolImage;
+            var imageDimensions = new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width };
+            return sourceImage.ToDenseTensor(imageDimensions, false);
+        }
+ }
+}
diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetImageDiffuser.cs
new file mode 100644
index 00000000..e85dc7bf
--- /dev/null
+++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetImageDiffuser.cs
@@ -0,0 +1,83 @@
+using Microsoft.Extensions.Logging;
+using Microsoft.ML.OnnxRuntime.Tensors;
+using OnnxStack.Core;
+using OnnxStack.Core.Config;
+using OnnxStack.Core.Model;
+using OnnxStack.Core.Services;
+using OnnxStack.StableDiffusion.Common;
+using OnnxStack.StableDiffusion.Config;
+using OnnxStack.StableDiffusion.Enums;
+using OnnxStack.StableDiffusion.Helpers;
+using SixLabors.ImageSharp;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+
+namespace OnnxStack.StableDiffusion.Diffusers.StableDiffusionXL
+{
+ public sealed class ControlNetImageDiffuser : ControlNetDiffuser
+ {
+        /// <summary>
+        /// Initializes a new instance of the <see cref="ControlNetImageDiffuser"/> class; all arguments are passed to the base constructor.
+        /// </summary>
+        /// <param name="onnxModelService">The onnx model service.</param>
+        /// <param name="promptService">The prompt service.</param>
+        /// <param name="controlNetImageService">The ControlNet image service.</param>
+        /// <param name="logger">The logger.</param>
+        public ControlNetImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, IControlNetImageService controlNetImageService, ILogger logger)
+            : base(onnxModelService, promptService, controlNetImageService, logger) { }
+
+
+        /// <summary>
+        /// Gets the type of the diffuser, which is <c>DiffuserType.ControlNetImage</c> for this class.
+        /// </summary>
+        public override DiffuserType DiffuserType { get { return DiffuserType.ControlNetImage; } }
+
+
+        /// <summary>
+        /// Gets the timesteps to run, skipping leading scheduler timesteps according to the strength.
+        /// </summary>
+        /// <param name="options">The scheduler options; supplies <c>InferenceSteps</c> and <c>Strength</c>.</param>
+        /// <param name="scheduler">The scheduler that supplies the timesteps.</param>
+        /// <returns>The scheduler timesteps that remain after the skipped ones, as a new list.</returns>
+        protected override IReadOnlyList GetTimesteps(SchedulerOptions options, IScheduler scheduler)
+        {
+            var totalSteps = options.InferenceSteps;
+            var stepsToRun = Math.Min((int)(totalSteps * options.Strength), totalSteps);
+            var skipCount = Math.Max(totalSteps - stepsToRun, 0);
+            return scheduler.Timesteps.Skip(skipCount).ToList();
+        }
+
+
+        /// <summary>Prepares the latents for inference: encodes the input image, scales the result and adds scheduler noise to it.</summary>
+        /// <param name="model">The model options; supplies the base model and the scale factor.</param>
+        /// <param name="prompt">The prompt options; supplies the input image.</param>
+        /// <param name="options">The scheduler options; supplies the height, width and scaled dimension.</param>
+        /// <param name="scheduler">The scheduler used to create the random sample and add the noise.</param>
+        /// <param name="timesteps">The timesteps passed to the scheduler when adding noise.</param>
+        /// <returns>The value returned by <c>scheduler.AddNoise</c> for the scaled sample.</returns>
+        protected override async Task> PrepareLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps)
+        {
+            // Convert the input image to a tensor for the VaeEncoder model
+            var imageDimensions = new[] { 1, 3, options.Height, options.Width };
+            var inputImageTensor = prompt.InputImage.ToDenseTensor(imageDimensions);
+
+            //TODO: Model Config, Channels
+            var latentDimension = options.GetScaledDimension();
+            var encoderMetadata = _onnxModelService.GetModelMetadata(model.BaseModel, OnnxModelType.VaeEncoder);
+            using var encoderParameters = new OnnxInferenceParameters(encoderMetadata);
+            encoderParameters.AddInputTensor(inputImageTensor);
+            encoderParameters.AddOutputBuffer(latentDimension);
+
+            var encoderResults = await _onnxModelService.RunInferenceAsync(model.BaseModel, OnnxModelType.VaeEncoder, encoderParameters);
+            using var encoderResult = encoderResults.First();
+
+            // Scale the encoded sample, then add noise from a fresh random sample
+            var scaledLatents = encoderResult.ToDenseTensor().MultiplyBy(model.ScaleFactor);
+            var noise = scheduler.CreateRandomSample(scaledLatents.Dimensions);
+            return scheduler.AddNoise(scaledLatents, noise, timesteps);
+        }
+
+ }
+}
diff --git a/OnnxStack.StableDiffusion/Registration.cs b/OnnxStack.StableDiffusion/Registration.cs
index ad88d175..e9003d5f 100644
--- a/OnnxStack.StableDiffusion/Registration.cs
+++ b/OnnxStack.StableDiffusion/Registration.cs
@@ -69,19 +69,26 @@ private static void RegisterServices(this IServiceCollection serviceCollection)
serviceCollection.AddSingleton();
serviceCollection.AddSingleton();
serviceCollection.AddSingleton();
+ serviceCollection.AddSingleton();
+ serviceCollection.AddSingleton();
//LatentConsistency
serviceCollection.AddSingleton();
serviceCollection.AddSingleton();
serviceCollection.AddSingleton();
+ serviceCollection.AddSingleton();
+ serviceCollection.AddSingleton();
//LatentConsistencyXL
serviceCollection.AddSingleton();
serviceCollection.AddSingleton();
serviceCollection.AddSingleton();
+ serviceCollection.AddSingleton();
+ serviceCollection.AddSingleton();
//InstaFlow
serviceCollection.AddSingleton();
+ serviceCollection.AddSingleton();
}