diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/LatentConsistencyXLDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/LatentConsistencyXLDiffuser.cs
index 73af2ef8..537a2170 100644
--- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/LatentConsistencyXLDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/LatentConsistencyXLDiffuser.cs
@@ -1,4 +1,6 @@
using Microsoft.Extensions.Logging;
+using Microsoft.ML.OnnxRuntime.Tensors;
+using OnnxStack.Core;
using OnnxStack.Core.Model;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
@@ -6,6 +8,11 @@
using OnnxStack.StableDiffusion.Enums;
using OnnxStack.StableDiffusion.Models;
using OnnxStack.StableDiffusion.Schedulers.LatentConsistency;
+using System.Diagnostics;
+using System.Linq;
+using System.Threading.Tasks;
+using System.Threading;
+using System;
namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistencyXL
{
@@ -29,6 +36,92 @@ protected LatentConsistencyXLDiffuser(UNetConditionModel unet, AutoEncoderModel
public override DiffuserPipelineType PipelineType => DiffuserPipelineType.LatentConsistencyXL;
+ ///
+ /// Runs the scheduler steps.
+ ///
+ /// The model options.
+ /// The prompt options.
+ /// The scheduler options.
+ /// The prompt embeddings.
+ /// if set to true [perform guidance].
+ /// The progress callback.
+ /// The cancellation token.
+ ///
+ public override async Task> DiffuseAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default)
+ {
+ // Get Scheduler
+ using (var scheduler = GetScheduler(schedulerOptions))
+ {
+ // Get timesteps
+ var timesteps = GetTimesteps(schedulerOptions, scheduler);
+
+ // Create latent sample
+ var latents = await PrepareLatentsAsync(promptOptions, schedulerOptions, scheduler, timesteps);
+
+ // Get Model metadata
+ var metadata = await _unet.GetMetadataAsync();
+
+ // Get Time ids
+ var addTimeIds = GetAddTimeIds(schedulerOptions);
+
+ // Get Guidance Scale Embedding
+ var guidanceEmbeddings = GetGuidanceScaleEmbedding(schedulerOptions.GuidanceScale);
+
+ // Loop though the timesteps
+ var step = 0;
+ foreach (var timestep in timesteps)
+ {
+ step++;
+ var stepTime = Stopwatch.GetTimestamp();
+ cancellationToken.ThrowIfCancellationRequested();
+
+ // Create input tensor.
+ var inputLatent = performGuidance ? latents.Repeat(2) : latents;
+ var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
+ var timestepTensor = CreateTimestepTensor(timestep);
+ var timeids = performGuidance ? addTimeIds.Repeat(2) : addTimeIds;
+
+ var outputChannels = performGuidance ? 2 : 1;
+ var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
+ using (var inferenceParameters = new OnnxInferenceParameters(metadata))
+ {
+ inferenceParameters.AddInputTensor(inputTensor);
+ inferenceParameters.AddInputTensor(timestepTensor);
+ inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
+ if (inferenceParameters.InputCount == 6)
+ inferenceParameters.AddInputTensor(guidanceEmbeddings);
+ inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
+ inferenceParameters.AddInputTensor(timeids);
+ inferenceParameters.AddOutputBuffer(outputDimension);
+
+ var results = await _unet.RunInferenceAsync(inferenceParameters);
+ using (var result = results.First())
+ {
+ var noisePred = result.ToDenseTensor();
+
+ // Perform guidance
+ if (performGuidance)
+ noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
+
+ // Scheduler Step
+ latents = scheduler.Step(noisePred, timestep, latents).Result;
+ }
+ }
+
+ ReportProgress(progressCallback, step, timesteps.Count, latents);
+ _logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
+ }
+
+ // Unload if required
+ if (_memoryMode == MemoryModeType.Minimum)
+ await _unet.UnloadAsync();
+
+ // Decode Latents
+ return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
+ }
+ }
+
+
///
/// Gets the scheduler.
///
@@ -42,5 +135,26 @@ protected override IScheduler GetScheduler(SchedulerOptions options)
_ => default
};
}
+
+
+ ///
+ /// Gets the guidance scale embedding.
+ ///
+ /// The options.
+ /// The embedding dim.
+ ///
+ protected DenseTensor GetGuidanceScaleEmbedding(float guidance, int embeddingDim = 256)
+ {
+ var scale = (guidance - 1f) * 1000.0f;
+ var halfDim = embeddingDim / 2;
+ float log = MathF.Log(10000.0f) / (halfDim - 1);
+ var emb = Enumerable.Range(0, halfDim)
+ .Select(x => scale * MathF.Exp(-log * x))
+ .ToArray();
+ var embSin = emb.Select(MathF.Sin);
+ var embCos = emb.Select(MathF.Cos);
+ var guidanceEmbedding = embSin.Concat(embCos).ToArray();
+ return new DenseTensor(guidanceEmbedding, new[] { 1, embeddingDim });
+ }
}
}