diff --git a/OnnxStack.StableDiffusion/Diffusers/InstaFlow/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/InstaFlow/ControlNetDiffuser.cs new file mode 100644 index 00000000..547f55fd --- /dev/null +++ b/OnnxStack.StableDiffusion/Diffusers/InstaFlow/ControlNetDiffuser.cs @@ -0,0 +1,208 @@ +using Microsoft.Extensions.Logging; +using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core; +using OnnxStack.Core.Config; +using OnnxStack.Core.Model; +using OnnxStack.Core.Services; +using OnnxStack.StableDiffusion.Common; +using OnnxStack.StableDiffusion.Config; +using OnnxStack.StableDiffusion.Enums; +using OnnxStack.StableDiffusion.Helpers; +using OnnxStack.StableDiffusion.Models; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +namespace OnnxStack.StableDiffusion.Diffusers.InstaFlow +{ + public class ControlNetDiffuser : InstaFlowDiffuser + { + private readonly IControlNetImageService _controlNetImageService; + + /// + /// Initializes a new instance of the class. + /// + /// The configuration. + /// The onnx model service. + public ControlNetDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, IControlNetImageService controlNetImageService, ILogger logger) + : base(onnxModelService, promptService, logger) + { + _controlNetImageService = controlNetImageService; + } + + /// + /// Gets the type of the diffuser. + /// + public override DiffuserType DiffuserType => DiffuserType.ControlNet; + + + /// + /// Called on each Scheduler step. + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The prompt embeddings. + /// if set to true [perform guidance]. + /// The progress callback. + /// The cancellation token. 
/// <returns>The tensor returned by <c>DecodeLatentsAsync</c> for the final latents.</returns>
protected override async Task<DenseTensor<float>> SchedulerStepAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
{
    // NOTE(review): the generic arguments in this signature were not legible in the reviewed text;
    // DenseTensor<float> and Action<DiffusionProgress> are assumed - confirm against the base class.

    // Get Scheduler
    using (var scheduler = GetScheduler(schedulerOptions))
    {
        // Get timesteps
        var timesteps = GetTimesteps(schedulerOptions, scheduler);

        // Create latent sample
        var latents = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);

        // Get Unet model metadata
        var metadata = _onnxModelService.GetModelMetadata(modelOptions.BaseModel, OnnxModelType.Unet);

        // Get ControlNet model metadata
        var controlNetMetadata = _onnxModelService.GetModelMetadata(modelOptions.ControlNetModel, OnnxModelType.ControlNet);

        // Control image; duplicated once (not per step) when guidance doubles the latent batch,
        // so the ControlNet receives the same batch size as inputTensor and the prompt embeddings.
        var controlImage = await PrepareControlImage(modelOptions, promptOptions, schedulerOptions);
        var controlImageTensor = performGuidance ? controlImage.Repeat(2) : controlImage;

        // Conditioning scale does not change between steps, so build it once.
        var conditioningScale = CreateConditioningScaleTensor(schedulerOptions.ConditioningScale);

        // Batch size of the Unet output: 2 when guidance is performed, otherwise 1.
        var batchCount = performGuidance ? 2 : 1;

        // Get the distilled Timestep
        var distilledTimestep = 1.0f / timesteps.Count;

        // Loop though the timesteps
        var step = 0;
        foreach (var timestep in timesteps)
        {
            step++;
            var stepTime = Stopwatch.GetTimestamp();
            cancellationToken.ThrowIfCancellationRequested();

            // Create input tensor.
            var inputLatent = performGuidance ? latents.Repeat(2) : latents;
            var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
            var timestepTensor = CreateTimestepTensor(timestep);

            var outputDimension = schedulerOptions.GetScaledDimension(batchCount);
            using (var inferenceParameters = new OnnxInferenceParameters(metadata))
            {
                inferenceParameters.AddInputTensor(inputTensor);
                inferenceParameters.AddInputTensor(timestepTensor);
                inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);

                // ControlNet
                using (var controlNetParameters = new OnnxInferenceParameters(controlNetMetadata))
                {
                    controlNetParameters.AddInputTensor(inputTensor);
                    controlNetParameters.AddInputTensor(timestepTensor);
                    controlNetParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
                    controlNetParameters.AddInputTensor(controlImageTensor);

                    // The conditioning scale is only supplied when the model declares a fifth input.
                    if (controlNetMetadata.Inputs.Count == 5)
                    {
                        controlNetParameters.AddInputTensor(conditioningScale);
                    }

                    // Optimization: Pre-allocate device buffers for outputs
                    foreach (var output in controlNetMetadata.Outputs)
                    {
                        controlNetParameters.AddOutputBuffer();
                    }

                    // ControlNet inference
                    var controlNetResults = _onnxModelService.RunInference(modelOptions.ControlNetModel, OnnxModelType.ControlNet, controlNetParameters);

                    // Add ControlNet outputs to Unet input
                    foreach (var item in controlNetResults)
                    {
                        inferenceParameters.AddInput(item);
                    }

                    // Add output buffer
                    inferenceParameters.AddOutputBuffer(outputDimension);

                    // Unet inference
                    var results = await _onnxModelService.RunInferenceAsync(modelOptions.BaseModel, OnnxModelType.Unet, inferenceParameters);
                    using (var result = results.First())
                    {
                        var noisePred = result.ToDenseTensor();

                        // Perform guidance
                        if (performGuidance)
                        {
                            noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
                        }

                        // Scheduler Step
                        latents = scheduler.Step(noisePred, timestep, latents).Result;

                        // NOTE(review): latents are updated a second time here with
                        // noisePred * (1 / timesteps.Count); kept as in the original -
                        // confirm this matches the base InstaFlow diffuser.
                        latents = noisePred
                            .MultiplyTensorByFloat(distilledTimestep)
                            .AddTensors(latents);
                    }
                }
            }

            ReportProgress(progressCallback, step, timesteps.Count, latents);
            _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
        }

        // Decode Latents
        return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, latents);
    }
}


/// <summary>
/// Gets the timesteps.
/// </summary>
/// <param name="options">The options.</param>
/// <param name="scheduler">The scheduler.</param>
/// <returns>The scheduler's timesteps, unchanged.</returns>
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
{
    // NOTE(review): element type int is assumed - confirm against IScheduler.Timesteps.
    return scheduler.Timesteps;
}


/// <summary>
/// Prepares the input latents.
/// </summary>
/// <param name="model">The model.</param>
/// <param name="prompt">The prompt.</param>
/// <param name="options">The options.</param>
/// <param name="scheduler">The scheduler.</param>
/// <param name="timesteps">The timesteps.</param>
/// <returns>A completed task holding a random sample of the scaled dimension, created with the scheduler's InitNoiseSigma.</returns>
protected override Task<DenseTensor<float>> PrepareLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
{
    var sample = scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma);
    return Task.FromResult(sample);
}


/// <summary>
/// Creates the Conditioning Scale tensor.
/// </summary>
/// <param name="conditioningScale">The conditioningScale.</param>
/// <returns>A one-element tensor of shape [1] holding the scale widened to double.</returns>
protected static DenseTensor<double> CreateConditioningScaleTensor(float conditioningScale)
{
    return TensorHelper.CreateTensor(new double[] { conditioningScale }, new int[] { 1 });
}


/// <summary>
/// Prepares the control image.
/// </summary>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
+ /// + protected async Task> PrepareControlImage(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions) + { + var controlImage = promptOptions.InputContolImage; + if (schedulerOptions.IsControlImageProcessingEnabled) + { + controlImage = await _controlNetImageService.PrepareInputImage(modelOptions.ControlNetModel, promptOptions.InputContolImage, schedulerOptions.Height, schedulerOptions.Width); + } + return controlImage.ToDenseTensor(new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width }, false); + } + } +} diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetDiffuser.cs new file mode 100644 index 00000000..c9cab1a0 --- /dev/null +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetDiffuser.cs @@ -0,0 +1,206 @@ +using Microsoft.Extensions.Logging; +using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core; +using OnnxStack.Core.Config; +using OnnxStack.Core.Model; +using OnnxStack.Core.Services; +using OnnxStack.StableDiffusion.Common; +using OnnxStack.StableDiffusion.Config; +using OnnxStack.StableDiffusion.Enums; +using OnnxStack.StableDiffusion.Helpers; +using OnnxStack.StableDiffusion.Models; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency +{ + public class ControlNetDiffuser : LatentConsistencyDiffuser + { + private readonly IControlNetImageService _controlNetImageService; + + /// + /// Initializes a new instance of the class. + /// + /// The configuration. + /// The onnx model service. 
public ControlNetDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, IControlNetImageService controlNetImageService, ILogger<LatentConsistencyDiffuser> logger)
    : base(onnxModelService, promptService, logger)
{
    // NOTE(review): the ILogger type argument was not legible in the reviewed text;
    // LatentConsistencyDiffuser (the base class) is assumed - confirm.
    _controlNetImageService = controlNetImageService;
}

/// <summary>
/// Gets the type of the diffuser.
/// </summary>
public override DiffuserType DiffuserType => DiffuserType.ControlNet;


/// <summary>
/// Called on each Scheduler step.
/// </summary>
/// <param name="modelOptions">The model options.</param>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
/// <param name="promptEmbeddings">The prompt embeddings.</param>
/// <param name="performGuidance">if set to <c>true</c> [perform guidance]. Not read by this implementation.</param>
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The tensor returned by <c>DecodeLatentsAsync</c> for the last denoised sample.</returns>
protected override async Task<DenseTensor<float>> SchedulerStepAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
{
    // NOTE(review): the generic arguments in this signature were not legible in the reviewed text;
    // DenseTensor<float> and Action<DiffusionProgress> are assumed - confirm against the base class.

    // Get Scheduler
    using (var scheduler = GetScheduler(schedulerOptions))
    {
        // Get timesteps
        var timesteps = GetTimesteps(schedulerOptions, scheduler);

        // Create latent sample
        var latents = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);

        // Get Guidance Scale Embedding
        var guidanceEmbeddings = GetGuidanceScaleEmbedding(schedulerOptions.GuidanceScale);

        // Get Unet model metadata
        var metadata = _onnxModelService.GetModelMetadata(modelOptions.BaseModel, OnnxModelType.Unet);

        // Get ControlNet model metadata
        var controlNetMetadata = _onnxModelService.GetModelMetadata(modelOptions.ControlNetModel, OnnxModelType.ControlNet);

        // Control Image
        var controlImageTensor = await PrepareControlImage(modelOptions, promptOptions, schedulerOptions);

        // Conditioning scale does not change between steps, so build it once.
        var conditioningScale = CreateConditioningScaleTensor(schedulerOptions.ConditioningScale);

        // Denoised result of the most recent scheduler step
        DenseTensor<float> denoised = null;

        // Loop though the timesteps
        var step = 0;
        foreach (var timestep in timesteps)
        {
            step++;
            var stepTime = Stopwatch.GetTimestamp();
            cancellationToken.ThrowIfCancellationRequested();

            // Create input tensor.
            var inputTensor = scheduler.ScaleInput(latents, timestep);
            var timestepTensor = CreateTimestepTensor(timestep);

            var batchCount = 1;
            var outputDimension = schedulerOptions.GetScaledDimension(batchCount);
            using (var inferenceParameters = new OnnxInferenceParameters(metadata))
            {
                inferenceParameters.AddInputTensor(inputTensor);
                inferenceParameters.AddInputTensor(timestepTensor);
                inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
                inferenceParameters.AddInputTensor(guidanceEmbeddings);

                // ControlNet
                using (var controlNetParameters = new OnnxInferenceParameters(controlNetMetadata))
                {
                    controlNetParameters.AddInputTensor(inputTensor);
                    controlNetParameters.AddInputTensor(timestepTensor);
                    controlNetParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
                    controlNetParameters.AddInputTensor(guidanceEmbeddings);
                    controlNetParameters.AddInputTensor(controlImageTensor);

                    // NOTE(review): five inputs have already been added above, so a model with
                    // exactly five declared inputs has no slot left for the scale - confirm the
                    // expected input count for this ControlNet variant.
                    if (controlNetMetadata.Inputs.Count == 5)
                    {
                        controlNetParameters.AddInputTensor(conditioningScale);
                    }

                    // Optimization: Pre-allocate device buffers for outputs
                    foreach (var output in controlNetMetadata.Outputs)
                    {
                        controlNetParameters.AddOutputBuffer();
                    }

                    // ControlNet inference
                    var controlNetResults = _onnxModelService.RunInference(modelOptions.ControlNetModel, OnnxModelType.ControlNet, controlNetParameters);

                    // Add ControlNet outputs to Unet input
                    foreach (var item in controlNetResults)
                    {
                        inferenceParameters.AddInput(item);
                    }

                    // Add output buffer
                    inferenceParameters.AddOutputBuffer(outputDimension);

                    // Unet inference
                    var results = await _onnxModelService.RunInferenceAsync(modelOptions.BaseModel, OnnxModelType.Unet, inferenceParameters);
                    using (var result = results.First())
                    {
                        var noisePred = result.ToDenseTensor();

                        // Scheduler Step
                        var schedulerResult = scheduler.Step(noisePred, timestep, latents);

                        latents = schedulerResult.Result;
                        denoised = schedulerResult.SampleData;
                    }
                }
            }

            ReportProgress(progressCallback, step, timesteps.Count, latents);
            _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
        }

        // Decode the denoised sample tracked above (it was previously assigned but never used).
        // Falls back to the latents when no step ran and denoised is still null.
        return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, denoised ?? latents);
    }
}


/// <summary>
/// Gets the timesteps.
/// </summary>
/// <param name="options">The options.</param>
/// <param name="scheduler">The scheduler.</param>
/// <returns>The scheduler's timesteps, unchanged.</returns>
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
{
    // NOTE(review): element type int is assumed - confirm against IScheduler.Timesteps.
    return scheduler.Timesteps;
}


/// <summary>
/// Prepares the input latents.
/// </summary>
/// <param name="model">The model.</param>
/// <param name="prompt">The prompt.</param>
/// <param name="options">The options.</param>
/// <param name="scheduler">The scheduler.</param>
/// <param name="timesteps">The timesteps.</param>
/// <returns>A completed task holding a random sample of the scaled dimension, created with the scheduler's InitNoiseSigma.</returns>
protected override Task<DenseTensor<float>> PrepareLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
{
    var sample = scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma);
    return Task.FromResult(sample);
}


/// <summary>
/// Creates the Conditioning Scale tensor.
/// </summary>
/// <param name="conditioningScale">The conditioningScale.</param>
/// <returns>A one-element tensor of shape [1] holding the scale widened to double.</returns>
protected static DenseTensor<double> CreateConditioningScaleTensor(float conditioningScale)
{
    return TensorHelper.CreateTensor(new double[] { conditioningScale }, new int[] { 1 });
}


/// <summary>
/// Prepares the control image.
/// </summary>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
+ /// + protected async Task> PrepareControlImage(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions) + { + var controlImage = promptOptions.InputContolImage; + if (schedulerOptions.IsControlImageProcessingEnabled) + { + controlImage = await _controlNetImageService.PrepareInputImage(modelOptions.ControlNetModel, promptOptions.InputContolImage, schedulerOptions.Height, schedulerOptions.Width); + } + return controlImage.ToDenseTensor(new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width }, false); + } + } +} diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetImageDiffuser.cs new file mode 100644 index 00000000..e535989d --- /dev/null +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetImageDiffuser.cs @@ -0,0 +1,83 @@ +using Microsoft.Extensions.Logging; +using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core; +using OnnxStack.Core.Config; +using OnnxStack.Core.Model; +using OnnxStack.Core.Services; +using OnnxStack.StableDiffusion.Common; +using OnnxStack.StableDiffusion.Config; +using OnnxStack.StableDiffusion.Enums; +using OnnxStack.StableDiffusion.Helpers; +using SixLabors.ImageSharp; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency +{ + public sealed class ControlNetImageDiffuser : ControlNetDiffuser + { + /// + /// Initializes a new instance of the class. + /// + /// The configuration. + /// The onnx model service. + public ControlNetImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, IControlNetImageService controlNetImageService, ILogger logger) + : base(onnxModelService, promptService, controlNetImageService, logger) + { + } + + + /// + /// Gets the type of the diffuser. 
+ /// + public override DiffuserType DiffuserType => DiffuserType.ControlNetImage; + + + /// + /// Gets the timesteps. + /// + /// The prompt. + /// The options. + /// The scheduler. + /// + protected override IReadOnlyList GetTimesteps(SchedulerOptions options, IScheduler scheduler) + { + var inittimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps); + var start = Math.Max(options.InferenceSteps - inittimestep, 0); + return scheduler.Timesteps.Skip(start).ToList(); + } + + + /// + /// Prepares the latents for inference. + /// + /// The prompt. + /// The options. + /// The scheduler. + /// + protected override async Task> PrepareLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) + { + var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width }); + + //TODO: Model Config, Channels + var outputDimension = options.GetScaledDimension(); + var metadata = _onnxModelService.GetModelMetadata(model.BaseModel, OnnxModelType.VaeEncoder); + using (var inferenceParameters = new OnnxInferenceParameters(metadata)) + { + inferenceParameters.AddInputTensor(imageTensor); + inferenceParameters.AddOutputBuffer(outputDimension); + + var results = await _onnxModelService.RunInferenceAsync(model.BaseModel, OnnxModelType.VaeEncoder, inferenceParameters); + using (var result = results.First()) + { + var outputResult = result.ToDenseTensor(); + var scaledSample = outputResult.MultiplyBy(model.ScaleFactor); + return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps); + } + } + } + + } +} diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetDiffuser.cs new file mode 100644 index 00000000..e1f265f3 --- /dev/null +++ 
b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetDiffuser.cs @@ -0,0 +1,211 @@ +using Microsoft.Extensions.Logging; +using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core; +using OnnxStack.Core.Config; +using OnnxStack.Core.Model; +using OnnxStack.Core.Services; +using OnnxStack.StableDiffusion.Common; +using OnnxStack.StableDiffusion.Config; +using OnnxStack.StableDiffusion.Enums; +using OnnxStack.StableDiffusion.Helpers; +using OnnxStack.StableDiffusion.Models; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistencyXL +{ + public class ControlNetDiffuser : LatentConsistencyXLDiffuser + { + private readonly IControlNetImageService _controlNetImageService; + + /// + /// Initializes a new instance of the class. + /// + /// The configuration. + /// The onnx model service. + public ControlNetDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, IControlNetImageService controlNetImageService, ILogger logger) + : base(onnxModelService, promptService, logger) + { + _controlNetImageService = controlNetImageService; + } + + + /// + /// Gets the type of the diffuser. + /// + public override DiffuserType DiffuserType => DiffuserType.ControlNet; + + + /// + /// Called on each Scheduler step. + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The prompt embeddings. + /// if set to true [perform guidance]. + /// The progress callback. + /// The cancellation token. 
/// <returns>The tensor returned by <c>DecodeLatentsAsync</c> for the final latents.</returns>
protected override async Task<DenseTensor<float>> SchedulerStepAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
{
    // NOTE(review): the generic arguments in this signature were not legible in the reviewed text;
    // DenseTensor<float> and Action<DiffusionProgress> are assumed - confirm against the base class.

    // Get Scheduler
    using (var scheduler = GetScheduler(schedulerOptions))
    {
        // Get timesteps
        var timesteps = GetTimesteps(schedulerOptions, scheduler);

        // Create latent sample
        var latents = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);

        // Get Unet model metadata
        var metadata = _onnxModelService.GetModelMetadata(modelOptions.BaseModel, OnnxModelType.Unet);

        // Get Time ids; duplicated once (not per step) when guidance doubles the batch.
        var addTimeIds = GetAddTimeIds(modelOptions, schedulerOptions);
        var timeids = performGuidance ? addTimeIds.Repeat(2) : addTimeIds;

        // Get ControlNet model metadata
        var controlNetMetadata = _onnxModelService.GetModelMetadata(modelOptions.ControlNetModel, OnnxModelType.ControlNet);

        // Control image; duplicated once when guidance doubles the latent batch, so the
        // ControlNet receives the same batch size as inputTensor and the prompt embeddings.
        var controlImage = await PrepareControlImage(modelOptions, promptOptions, schedulerOptions);
        var controlImageTensor = performGuidance ? controlImage.Repeat(2) : controlImage;

        // Conditioning scale does not change between steps, so build it once.
        var conditioningScale = CreateConditioningScaleTensor(schedulerOptions.ConditioningScale);

        // Batch size of the Unet output: 2 when guidance is performed, otherwise 1.
        var batchCount = performGuidance ? 2 : 1;

        // Loop though the timesteps
        var step = 0;
        foreach (var timestep in timesteps)
        {
            step++;
            var stepTime = Stopwatch.GetTimestamp();
            cancellationToken.ThrowIfCancellationRequested();

            // Create input tensor.
            var inputLatent = performGuidance ? latents.Repeat(2) : latents;
            var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
            var timestepTensor = CreateTimestepTensor(timestep);

            var outputDimension = schedulerOptions.GetScaledDimension(batchCount);
            using (var inferenceParameters = new OnnxInferenceParameters(metadata))
            {
                inferenceParameters.AddInputTensor(inputTensor);
                inferenceParameters.AddInputTensor(timestepTensor);
                inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
                inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
                inferenceParameters.AddInputTensor(timeids);

                // NOTE(review): PromptEmbeds is added a second time here, before the ControlNet
                // outputs; kept as in the original - confirm the Unet really declares this input.
                inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);

                // ControlNet
                using (var controlNetParameters = new OnnxInferenceParameters(controlNetMetadata))
                {
                    controlNetParameters.AddInputTensor(inputTensor);
                    controlNetParameters.AddInputTensor(timestepTensor);
                    controlNetParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
                    controlNetParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
                    controlNetParameters.AddInputTensor(timeids);
                    controlNetParameters.AddInputTensor(controlImageTensor);

                    // NOTE(review): six inputs have already been added above, so a model with
                    // exactly five declared inputs cannot take the scale as a further input -
                    // confirm the expected input count for this ControlNet variant.
                    if (controlNetMetadata.Inputs.Count == 5)
                    {
                        controlNetParameters.AddInputTensor(conditioningScale);
                    }

                    // Optimization: Pre-allocate device buffers for outputs
                    foreach (var output in controlNetMetadata.Outputs)
                    {
                        controlNetParameters.AddOutputBuffer();
                    }

                    // ControlNet inference
                    var controlNetResults = _onnxModelService.RunInference(modelOptions.ControlNetModel, OnnxModelType.ControlNet, controlNetParameters);

                    // Add ControlNet outputs to Unet input
                    foreach (var item in controlNetResults)
                    {
                        inferenceParameters.AddInput(item);
                    }

                    // Add output buffer
                    inferenceParameters.AddOutputBuffer(outputDimension);

                    // Unet inference
                    var results = await _onnxModelService.RunInferenceAsync(modelOptions.BaseModel, OnnxModelType.Unet, inferenceParameters);
                    using (var result = results.First())
                    {
                        var noisePred = result.ToDenseTensor();

                        // Perform guidance
                        if (performGuidance)
                        {
                            noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
                        }

                        // Scheduler Step
                        latents = scheduler.Step(noisePred, timestep, latents).Result;
                    }
                }
            }

            ReportProgress(progressCallback, step, timesteps.Count, latents);
            _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
        }

        // Decode Latents
        return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, latents);
    }
}


/// <summary>
/// Gets the timesteps.
/// </summary>
/// <param name="options">The options.</param>
/// <param name="scheduler">The scheduler.</param>
/// <returns>The scheduler's timesteps, unchanged.</returns>
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
{
    // NOTE(review): element type int is assumed - confirm against IScheduler.Timesteps.
    return scheduler.Timesteps;
}


/// <summary>
/// Prepares the input latents.
/// </summary>
/// <param name="model">The model.</param>
/// <param name="prompt">The prompt.</param>
/// <param name="options">The options.</param>
/// <param name="scheduler">The scheduler.</param>
/// <param name="timesteps">The timesteps.</param>
/// <returns>A completed task holding a random sample of the scaled dimension, created with the scheduler's InitNoiseSigma.</returns>
protected override Task<DenseTensor<float>> PrepareLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
{
    var sample = scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma);
    return Task.FromResult(sample);
}


/// <summary>
/// Creates the Conditioning Scale tensor.
/// </summary>
/// <param name="conditioningScale">The conditioningScale.</param>
/// <returns>A one-element tensor of shape [1] holding the scale widened to double.</returns>
protected static DenseTensor<double> CreateConditioningScaleTensor(float conditioningScale)
{
    return TensorHelper.CreateTensor(new double[] { conditioningScale }, new int[] { 1 });
}


/// <summary>
/// Prepares the control image.
/// </summary>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
+ /// + protected async Task> PrepareControlImage(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions) + { + var controlImage = promptOptions.InputContolImage; + if (schedulerOptions.IsControlImageProcessingEnabled) + { + controlImage = await _controlNetImageService.PrepareInputImage(modelOptions.ControlNetModel, promptOptions.InputContolImage, schedulerOptions.Height, schedulerOptions.Width); + } + return controlImage.ToDenseTensor(new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width }, false); + } + } +} diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetImageDiffuser.cs new file mode 100644 index 00000000..3dd18a34 --- /dev/null +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/ControlNetImageDiffuser.cs @@ -0,0 +1,83 @@ +using Microsoft.Extensions.Logging; +using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core; +using OnnxStack.Core.Config; +using OnnxStack.Core.Model; +using OnnxStack.Core.Services; +using OnnxStack.StableDiffusion.Common; +using OnnxStack.StableDiffusion.Config; +using OnnxStack.StableDiffusion.Enums; +using OnnxStack.StableDiffusion.Helpers; +using SixLabors.ImageSharp; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistencyXL +{ + public sealed class ControlNetImageDiffuser : ControlNetDiffuser + { + /// + /// Initializes a new instance of the class. + /// + /// The configuration. + /// The onnx model service. + public ControlNetImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, IControlNetImageService controlNetImageService, ILogger logger) + : base(onnxModelService, promptService, controlNetImageService, logger) + { + } + + + /// + /// Gets the type of the diffuser. 
+ /// + public override DiffuserType DiffuserType => DiffuserType.ControlNetImage; + + + /// + /// Gets the timesteps. + /// + /// The prompt. + /// The options. + /// The scheduler. + /// + protected override IReadOnlyList GetTimesteps(SchedulerOptions options, IScheduler scheduler) + { + var inittimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps); + var start = Math.Max(options.InferenceSteps - inittimestep, 0); + return scheduler.Timesteps.Skip(start).ToList(); + } + + + /// + /// Prepares the latents for inference. + /// + /// The prompt. + /// The options. + /// The scheduler. + /// + protected override async Task> PrepareLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) + { + var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width }); + + //TODO: Model Config, Channels + var outputDimension = options.GetScaledDimension(); + var metadata = _onnxModelService.GetModelMetadata(model.BaseModel, OnnxModelType.VaeEncoder); + using (var inferenceParameters = new OnnxInferenceParameters(metadata)) + { + inferenceParameters.AddInputTensor(imageTensor); + inferenceParameters.AddOutputBuffer(outputDimension); + + var results = await _onnxModelService.RunInferenceAsync(model.BaseModel, OnnxModelType.VaeEncoder, inferenceParameters); + using (var result = results.First()) + { + var outputResult = result.ToDenseTensor(); + var scaledSample = outputResult.MultiplyBy(model.ScaleFactor); + return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps); + } + } + } + + } +} diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetDiffuser.cs index 189d0b83..4b97b7e2 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetDiffuser.cs +++ 
b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetDiffuser.cs @@ -9,7 +9,6 @@ using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Helpers; using OnnxStack.StableDiffusion.Models; -using OnnxStack.StableDiffusion.Schedulers.StableDiffusion; using System; using System.Collections.Generic; using System.Diagnostics; @@ -19,7 +18,7 @@ namespace OnnxStack.StableDiffusion.Diffusers.StableDiffusion { - public class ControlNetDiffuser : DiffuserBase + public class ControlNetDiffuser : StableDiffusionDiffuser { private readonly IControlNetImageService _controlNetImageService; @@ -35,12 +34,6 @@ public ControlNetDiffuser(IOnnxModelService onnxModelService, IPromptService pro } - /// - /// Gets the type of the pipeline. - /// - public override DiffuserPipelineType PipelineType => DiffuserPipelineType.StableDiffusion; - - /// /// Gets the type of the diffuser. /// @@ -205,26 +198,5 @@ protected async Task> PrepareControlImage(ModelOptions modelO } return controlImage.ToDenseTensor(new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width }, false); } - - - /// - /// Gets the scheduler. - /// - /// The options. - /// The scheduler configuration. 
- /// - protected override IScheduler GetScheduler(SchedulerOptions options) - { - return options.SchedulerType switch - { - SchedulerType.LMS => new LMSScheduler(options), - SchedulerType.Euler => new EulerScheduler(options), - SchedulerType.EulerAncestral => new EulerAncestralScheduler(options), - SchedulerType.DDPM => new DDPMScheduler(options), - SchedulerType.DDIM => new DDIMScheduler(options), - SchedulerType.KDPM2 => new KDPM2Scheduler(options), - _ => default - }; - } } } diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetDiffuser.cs new file mode 100644 index 00000000..d0d921b1 --- /dev/null +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetDiffuser.cs @@ -0,0 +1,211 @@ +using Microsoft.Extensions.Logging; +using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core; +using OnnxStack.Core.Config; +using OnnxStack.Core.Model; +using OnnxStack.Core.Services; +using OnnxStack.StableDiffusion.Common; +using OnnxStack.StableDiffusion.Config; +using OnnxStack.StableDiffusion.Enums; +using OnnxStack.StableDiffusion.Helpers; +using OnnxStack.StableDiffusion.Models; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +namespace OnnxStack.StableDiffusion.Diffusers.StableDiffusionXL +{ + public class ControlNetDiffuser : StableDiffusionXLDiffuser + { + private readonly IControlNetImageService _controlNetImageService; + + /// + /// Initializes a new instance of the class. + /// + /// The configuration. + /// The onnx model service. + public ControlNetDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, IControlNetImageService controlNetImageService, ILogger logger) + : base(onnxModelService, promptService, logger) + { + _controlNetImageService = controlNetImageService; + } + + + /// + /// Gets the type of the diffuser. 
/// </summary>
public override DiffuserType DiffuserType => DiffuserType.ControlNet;


/// <summary>
/// Called on each Scheduler step.
/// </summary>
/// <param name="modelOptions">The model options.</param>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
/// <param name="promptEmbeddings">The prompt embeddings.</param>
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The tensor returned by <c>DecodeLatentsAsync</c> for the final latents.</returns>
protected override async Task<DenseTensor<float>> SchedulerStepAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
{
    // NOTE(review): the generic arguments in this signature were not legible in the reviewed text;
    // DenseTensor<float> and Action<DiffusionProgress> are assumed - confirm against the base class.

    // Get Scheduler
    using (var scheduler = GetScheduler(schedulerOptions))
    {
        // Get timesteps
        var timesteps = GetTimesteps(schedulerOptions, scheduler);

        // Create latent sample
        var latents = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);

        // Get Unet model metadata
        var metadata = _onnxModelService.GetModelMetadata(modelOptions.BaseModel, OnnxModelType.Unet);

        // Get Time ids; duplicated once (not per step) when guidance doubles the batch.
        var addTimeIds = GetAddTimeIds(modelOptions, schedulerOptions);
        var timeids = performGuidance ? addTimeIds.Repeat(2) : addTimeIds;

        // Get ControlNet model metadata
        var controlNetMetadata = _onnxModelService.GetModelMetadata(modelOptions.ControlNetModel, OnnxModelType.ControlNet);

        // Control image; duplicated once when guidance doubles the latent batch, so the
        // ControlNet receives the same batch size as inputTensor and the prompt embeddings.
        var controlImage = await PrepareControlImage(modelOptions, promptOptions, schedulerOptions);
        var controlImageTensor = performGuidance ? controlImage.Repeat(2) : controlImage;

        // Conditioning scale does not change between steps, so build it once.
        var conditioningScale = CreateConditioningScaleTensor(schedulerOptions.ConditioningScale);

        // Batch size of the Unet output: 2 when guidance is performed, otherwise 1.
        var batchCount = performGuidance ? 2 : 1;

        // Loop though the timesteps
        var step = 0;
        foreach (var timestep in timesteps)
        {
            step++;
            var stepTime = Stopwatch.GetTimestamp();
            cancellationToken.ThrowIfCancellationRequested();

            // Create input tensor.
            var inputLatent = performGuidance ? latents.Repeat(2) : latents;
            var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
            var timestepTensor = CreateTimestepTensor(timestep);

            var outputDimension = schedulerOptions.GetScaledDimension(batchCount);
            using (var inferenceParameters = new OnnxInferenceParameters(metadata))
            {
                inferenceParameters.AddInputTensor(inputTensor);
                inferenceParameters.AddInputTensor(timestepTensor);
                inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
                inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
                inferenceParameters.AddInputTensor(timeids);

                // NOTE(review): PromptEmbeds is added a second time here, before the ControlNet
                // outputs; kept as in the original - confirm the Unet really declares this input.
                inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);

                // ControlNet
                using (var controlNetParameters = new OnnxInferenceParameters(controlNetMetadata))
                {
                    controlNetParameters.AddInputTensor(inputTensor);
                    controlNetParameters.AddInputTensor(timestepTensor);
                    controlNetParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
                    controlNetParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
                    controlNetParameters.AddInputTensor(timeids);
                    controlNetParameters.AddInputTensor(controlImageTensor);

                    // NOTE(review): six inputs have already been added above, so a model with
                    // exactly five declared inputs cannot take the scale as a further input -
                    // confirm the expected input count for this ControlNet variant.
                    if (controlNetMetadata.Inputs.Count == 5)
                    {
                        controlNetParameters.AddInputTensor(conditioningScale);
                    }

                    // Optimization: Pre-allocate device buffers for outputs
                    foreach (var output in controlNetMetadata.Outputs)
                    {
                        controlNetParameters.AddOutputBuffer();
                    }

                    // ControlNet inference
                    var controlNetResults = _onnxModelService.RunInference(modelOptions.ControlNetModel, OnnxModelType.ControlNet, controlNetParameters);

                    // Add ControlNet outputs to Unet input
                    foreach (var item in controlNetResults)
                    {
                        inferenceParameters.AddInput(item);
                    }

                    // Add output buffer
                    inferenceParameters.AddOutputBuffer(outputDimension);

                    // Unet inference
                    var results = await _onnxModelService.RunInferenceAsync(modelOptions.BaseModel, OnnxModelType.Unet, inferenceParameters);
                    using (var result = results.First())
                    {
                        var noisePred = result.ToDenseTensor();

                        // Perform guidance
                        if (performGuidance)
                        {
                            noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
                        }

                        // Scheduler Step
                        latents = scheduler.Step(noisePred, timestep, latents).Result;
                    }
                }
            }

            ReportProgress(progressCallback, step, timesteps.Count, latents);
            _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
        }

        // Decode Latents
        return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, latents);
    }
}


/// <summary>
/// Gets the timesteps.
/// </summary>
/// <param name="options">The options.</param>
/// <param name="scheduler">The scheduler.</param>
/// <returns>The scheduler's timesteps, unchanged.</returns>
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
{
    // NOTE(review): element type int is assumed - confirm against IScheduler.Timesteps.
    return scheduler.Timesteps;
}


/// <summary>
/// Prepares the input latents.
/// </summary>
/// <param name="model">The model.</param>
/// <param name="prompt">The prompt.</param>
/// <param name="options">The options.</param>
/// <param name="scheduler">The scheduler.</param>
/// <param name="timesteps">The timesteps.</param>
/// <returns>A completed task holding a random sample of the scaled dimension, created with the scheduler's InitNoiseSigma.</returns>
protected override Task<DenseTensor<float>> PrepareLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
{
    var sample = scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma);
    return Task.FromResult(sample);
}


/// <summary>
/// Creates the Conditioning Scale tensor.
/// </summary>
/// <param name="conditioningScale">The conditioningScale.</param>
/// <returns>A one-element tensor of shape [1] holding the scale widened to double.</returns>
protected static DenseTensor<double> CreateConditioningScaleTensor(float conditioningScale)
{
    return TensorHelper.CreateTensor(new double[] { conditioningScale }, new int[] { 1 });
}


/// <summary>
/// Prepares the control image.
/// </summary>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
+ /// + protected async Task> PrepareControlImage(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions) + { + var controlImage = promptOptions.InputContolImage; + if (schedulerOptions.IsControlImageProcessingEnabled) + { + controlImage = await _controlNetImageService.PrepareInputImage(modelOptions.ControlNetModel, promptOptions.InputContolImage, schedulerOptions.Height, schedulerOptions.Width); + } + return controlImage.ToDenseTensor(new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width }, false); + } + } +} diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetImageDiffuser.cs new file mode 100644 index 00000000..e85dc7bf --- /dev/null +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/ControlNetImageDiffuser.cs @@ -0,0 +1,83 @@ +using Microsoft.Extensions.Logging; +using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core; +using OnnxStack.Core.Config; +using OnnxStack.Core.Model; +using OnnxStack.Core.Services; +using OnnxStack.StableDiffusion.Common; +using OnnxStack.StableDiffusion.Config; +using OnnxStack.StableDiffusion.Enums; +using OnnxStack.StableDiffusion.Helpers; +using SixLabors.ImageSharp; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +namespace OnnxStack.StableDiffusion.Diffusers.StableDiffusionXL +{ + public sealed class ControlNetImageDiffuser : ControlNetDiffuser + { + /// + /// Initializes a new instance of the class. + /// + /// The configuration. + /// The onnx model service. + public ControlNetImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, IControlNetImageService controlNetImageService, ILogger logger) + : base(onnxModelService, promptService, controlNetImageService, logger) + { + } + + + /// + /// Gets the type of the diffuser. 
+ /// + public override DiffuserType DiffuserType => DiffuserType.ControlNetImage; + + + /// + /// Gets the timesteps. + /// + /// The prompt. + /// The options. + /// The scheduler. + /// + protected override IReadOnlyList GetTimesteps(SchedulerOptions options, IScheduler scheduler) + { + var inittimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps); + var start = Math.Max(options.InferenceSteps - inittimestep, 0); + return scheduler.Timesteps.Skip(start).ToList(); + } + + + /// + /// Prepares the latents for inference. + /// + /// The prompt. + /// The options. + /// The scheduler. + /// + protected override async Task> PrepareLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) + { + var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width }); + + //TODO: Model Config, Channels + var outputDimension = options.GetScaledDimension(); + var metadata = _onnxModelService.GetModelMetadata(model.BaseModel, OnnxModelType.VaeEncoder); + using (var inferenceParameters = new OnnxInferenceParameters(metadata)) + { + inferenceParameters.AddInputTensor(imageTensor); + inferenceParameters.AddOutputBuffer(outputDimension); + + var results = await _onnxModelService.RunInferenceAsync(model.BaseModel, OnnxModelType.VaeEncoder, inferenceParameters); + using (var result = results.First()) + { + var outputResult = result.ToDenseTensor(); + var scaledSample = outputResult.MultiplyBy(model.ScaleFactor); + return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps); + } + } + } + + } +} diff --git a/OnnxStack.StableDiffusion/Registration.cs b/OnnxStack.StableDiffusion/Registration.cs index ad88d175..e9003d5f 100644 --- a/OnnxStack.StableDiffusion/Registration.cs +++ b/OnnxStack.StableDiffusion/Registration.cs @@ -69,19 +69,26 @@ private static void RegisterServices(this IServiceCollection 
serviceCollection) serviceCollection.AddSingleton(); serviceCollection.AddSingleton(); serviceCollection.AddSingleton(); + serviceCollection.AddSingleton(); + serviceCollection.AddSingleton(); //LatentConsistency serviceCollection.AddSingleton(); serviceCollection.AddSingleton(); serviceCollection.AddSingleton(); + serviceCollection.AddSingleton(); + serviceCollection.AddSingleton(); //LatentConsistencyXL serviceCollection.AddSingleton(); serviceCollection.AddSingleton(); serviceCollection.AddSingleton(); + serviceCollection.AddSingleton(); + serviceCollection.AddSingleton(); //InstaFlow serviceCollection.AddSingleton(); + serviceCollection.AddSingleton(); }