diff --git a/OnnxStack.Console/Examples/FeatureExtractorVideoExample.cs b/OnnxStack.Console/Examples/FeatureExtractorVideoExample.cs
new file mode 100644
index 00000000..467bfaab
--- /dev/null
+++ b/OnnxStack.Console/Examples/FeatureExtractorVideoExample.cs
@@ -0,0 +1,50 @@
+using OnnxStack.Core.Video;
+using OnnxStack.FeatureExtractor.Pipelines;
+using OnnxStack.StableDiffusion.Config;
+
+namespace OnnxStack.Console.Runner
+{
+    public sealed class FeatureExtractorVideoExample : IExampleRunner
+    {
+        private readonly string _outputDirectory;
+        private readonly StableDiffusionConfig _configuration;
+
+        public FeatureExtractorVideoExample(StableDiffusionConfig configuration)
+        {
+            _configuration = configuration;
+            _outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", nameof(FeatureExtractorVideoExample));
+            Directory.CreateDirectory(_outputDirectory);
+        }
+
+        public int Index => 13;
+
+        public string Name => "Feature Extractor Video Example";
+
+        public string Description => "Video example using basic feature extractor";
+
+        /// <summary>
+        /// Feature Extractor Video Example
+        /// </summary>
+        public async Task RunAsync()
+        {
+            // Read Video
+            var videoFile = "C:\\Users\\Deven\\Pictures\\parrot.mp4";
+            var videoInfo = await VideoHelper.ReadVideoInfoAsync(videoFile);
+
+            // Create pipeline
+            var pipeline = FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\canny.onnx");
+
+            // Create Video Stream
+            var videoStream = VideoHelper.ReadVideoStreamAsync(videoFile, videoInfo.FrameRate);
+
+            // Create Pipeline Stream
+            var pipelineStream = pipeline.RunAsync(videoStream);
+
+            // Write Video Stream
+            await VideoHelper.WriteVideoStreamAsync(videoInfo, pipelineStream, Path.Combine(_outputDirectory, $"Result.mp4"));
+
+            // Unload
+            await pipeline.UnloadAsync();
+        }
+    }
+}
diff --git a/OnnxStack.Console/Examples/UpscaleExample.cs b/OnnxStack.Console/Examples/UpscaleExample.cs
index 15360de1..d031b1a5 100644
--- a/OnnxStack.Console/Examples/UpscaleExample.cs
+++ b/OnnxStack.Console/Examples/UpscaleExample.cs
@@ -29,13 +29,10 @@ public async Task RunAsync()
 
             // Run pipeline
             var result = await pipeline.RunAsync(inputImage);
-
-            // Create Image from Tensor result
-            var image = new OnnxImage(result, ImageNormalizeType.ZeroToOne);
-
+
             // Save Image File
             var outputFilename = Path.Combine(_outputDirectory, $"Upscaled.png");
-            await image.SaveAsync(outputFilename);
+            await result.SaveAsync(outputFilename);
 
             // Unload
             await pipeline.UnloadAsync();
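The hunk above works because ImageUpscalePipeline.RunAsync now returns an OnnxImage instead of a raw tensor (see the pipeline change further down). A minimal sketch of the new call shape, assuming `pipeline` is a loaded ImageUpscalePipeline and `inputImage` is an OnnxImage:

    // RunAsync now hands back an OnnxImage, so no manual tensor-to-image conversion is needed
    OnnxImage upscaled = await pipeline.RunAsync(inputImage);
    await upscaled.SaveAsync(outputFilename);
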
diff --git a/OnnxStack.Console/Examples/UpscaleStreamExample.cs b/OnnxStack.Console/Examples/UpscaleStreamExample.cs
new file mode 100644
index 00000000..0c52839b
--- /dev/null
+++ b/OnnxStack.Console/Examples/UpscaleStreamExample.cs
@@ -0,0 +1,48 @@
+using OnnxStack.Core.Video;
+using OnnxStack.FeatureExtractor.Pipelines;
+
+namespace OnnxStack.Console.Runner
+{
+    public sealed class UpscaleStreamExample : IExampleRunner
+    {
+        private readonly string _outputDirectory;
+
+        public UpscaleStreamExample()
+        {
+            _outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", nameof(UpscaleStreamExample));
+            Directory.CreateDirectory(_outputDirectory);
+        }
+
+        public int Index => 10;
+
+        public string Name => "Streaming Video Upscale Demo";
+
+        public string Description => "Upscales a video stream";
+
+        public async Task RunAsync()
+        {
+            // Read Video
+            var videoFile = "C:\\Users\\Deven\\Pictures\\parrot.mp4";
+            var videoInfo = await VideoHelper.ReadVideoInfoAsync(videoFile);
+
+            // Create pipeline
+            var pipeline = ImageUpscalePipeline.CreatePipeline("D:\\Repositories\\upscaler\\SwinIR\\003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.onnx", 4);
+
+            // Load pipeline
+            await pipeline.LoadAsync();
+
+            // Create Video Stream
+            var videoStream = VideoHelper.ReadVideoStreamAsync(videoFile, videoInfo.FrameRate);
+
+            // Create Pipeline Stream
+            var pipelineStream = pipeline.RunAsync(videoStream);
+
+            // Write Video Stream
+            await VideoHelper.WriteVideoStreamAsync(videoInfo, pipelineStream, Path.Combine(_outputDirectory, $"Result.mp4"));
+
+            // Unload
+            await pipeline.UnloadAsync();
+        }
+
+    }
+}
diff --git a/OnnxStack.Console/Examples/VideoToVideoStreamExample.cs b/OnnxStack.Console/Examples/VideoToVideoStreamExample.cs
new file mode 100644
index 00000000..961a196e
--- /dev/null
+++ b/OnnxStack.Console/Examples/VideoToVideoStreamExample.cs
@@ -0,0 +1,67 @@
+using OnnxStack.Core.Video;
+using OnnxStack.FeatureExtractor.Pipelines;
+using OnnxStack.StableDiffusion.Config;
+using OnnxStack.StableDiffusion.Enums;
+using OnnxStack.StableDiffusion.Pipelines;
+
+namespace OnnxStack.Console.Runner
+{
+    public sealed class VideoToVideoStreamExample : IExampleRunner
+    {
+        private readonly string _outputDirectory;
+        private readonly StableDiffusionConfig _configuration;
+
+        public VideoToVideoStreamExample(StableDiffusionConfig configuration)
+        {
+            _configuration = configuration;
+            _outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", nameof(VideoToVideoStreamExample));
+            Directory.CreateDirectory(_outputDirectory);
+        }
+
+        public int Index => 4;
+
+        public string Name => "Video To Video Stream Demo";
+
+        public string Description => "Video Stream Stable Diffusion Inference";
+
+        public async Task RunAsync()
+        {
+            // Read Video
+            var videoFile = "C:\\Users\\Deven\\Pictures\\gidsgphy.gif";
+            var videoInfo = await VideoHelper.ReadVideoInfoAsync(videoFile);
+
+            // Loop through the appsettings.json model sets
+            foreach (var modelSet in _configuration.ModelSets)
+            {
+                OutputHelpers.WriteConsole($"Loading Model `{modelSet.Name}`...", ConsoleColor.Cyan);
+
+                // Create Pipeline
+                var pipeline = PipelineBase.CreatePipeline(modelSet);
+
+                // Preload Models (optional)
+                await pipeline.LoadAsync();
+
+                // Add text and video to prompt
+                var promptOptions = new PromptOptions
+                {
+                    Prompt = "Iron Man",
+                    DiffuserType = DiffuserType.ImageToImage
+                };
+
+                // Create Video Stream
+                var videoStream = VideoHelper.ReadVideoStreamAsync(videoFile, videoInfo.FrameRate);
+
+                // Create Pipeline Stream
+                var pipelineStream = pipeline.GenerateVideoStreamAsync(videoStream, promptOptions, progressCallback: OutputHelpers.ProgressCallback);
+
+                // Write Video Stream
+                await VideoHelper.WriteVideoStreamAsync(videoInfo, pipelineStream, Path.Combine(_outputDirectory, $"{modelSet.PipelineType}.mp4"));
+
+                // Unload
+                await pipeline.UnloadAsync();
+            }
+        }
+    }
+}
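Because every stage exposes IAsyncEnumerable&lt;OnnxImage&gt;, stages compose: frames flow lazily from the ffmpeg reader, through one or more pipelines, into the ffmpeg writer without buffering the whole clip. A hedged sketch of chaining the canny feature extractor into the upscaler (an illustrative combination, not one of the examples above; model paths are placeholders):

    var videoInfo = await VideoHelper.ReadVideoInfoAsync(videoFile);
    var frames    = VideoHelper.ReadVideoStreamAsync(videoFile, videoInfo.FrameRate);

    var extractor = FeatureExtractorPipeline.CreatePipeline("annotators\\canny.onnx");      // placeholder path
    var upscaler  = ImageUpscalePipeline.CreatePipeline("upscaler\\model.onnx", 4);          // placeholder path
    await upscaler.LoadAsync();

    // each streaming RunAsync overload takes and returns IAsyncEnumerable<OnnxImage>, so the stages chain
    var processed = upscaler.RunAsync(extractor.RunAsync(frames));

    await VideoHelper.WriteVideoStreamAsync(videoInfo, processed, "Result.mp4");
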
diff --git a/OnnxStack.Core/Extensions/TensorExtension.cs b/OnnxStack.Core/Extensions/TensorExtension.cs
index a6c6456d..c6dbce45 100644
--- a/OnnxStack.Core/Extensions/TensorExtension.cs
+++ b/OnnxStack.Core/Extensions/TensorExtension.cs
@@ -1,6 +1,7 @@
 using Microsoft.ML.OnnxRuntime.Tensors;
 using System;
 using System.Collections.Generic;
+using System.Linq;
 
 namespace OnnxStack.Core
 {
@@ -75,6 +76,33 @@ public static IEnumerable<DenseTensor<float>> SplitBatch(this DenseTensor<float
         }
 
 
+        /// <summary>
+        /// Joins the tensors across the 0 axis.
+        /// </summary>
+        /// <param name="tensors">The tensors.</param>
+        /// <param name="axis">The axis.</param>
+        /// <returns></returns>
+        /// <exception cref="System.NotImplementedException">Only axis 0 is supported</exception>
+        public static DenseTensor<float> Join(this IList<DenseTensor<float>> tensors, int axis = 0)
+        {
+            if (axis != 0)
+                throw new NotImplementedException("Only axis 0 is supported");
+
+            var tensor = tensors.First();
+            var dimensions = tensor.Dimensions.ToArray();
+            dimensions[0] *= tensors.Count;
+
+            var newLength = (int)tensor.Length;
+            var buffer = new float[newLength * tensors.Count].AsMemory();
+            for (int i = 0; i < tensors.Count(); i++)
+            {
+                var start = i * newLength;
+                tensors[i].Buffer.CopyTo(buffer[start..]);
+            }
+            return new DenseTensor<float>(buffer, dimensions);
+        }
+
+
         /// <summary>
         /// Concatenates the specified tensors along the specified axis.
         /// </summary>
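As a worked example, joining two [1, 4, 64, 64] tensors with the new extension yields a single [2, 4, 64, 64] tensor whose buffer is the two source buffers copied back to back (a sketch, not code from the repository):

    var a = new DenseTensor<float>(new[] { 1, 4, 64, 64 });
    var b = new DenseTensor<float>(new[] { 1, 4, 64, 64 });

    // Join only supports axis 0; dimensions[0] becomes 1 + 1 = 2
    DenseTensor<float> joined = new[] { a, b }.Join();
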
diff --git a/OnnxStack.Core/Image/OnnxImage.cs b/OnnxStack.Core/Image/OnnxImage.cs
index d4d837bb..d0998e8a 100644
--- a/OnnxStack.Core/Image/OnnxImage.cs
+++ b/OnnxStack.Core/Image/OnnxImage.cs
@@ -144,6 +144,20 @@ public byte[] GetImageBytes()
         }
 
 
+        /// <summary>
+        /// Gets the image as bytes.
+        /// </summary>
+        /// <returns></returns>
+        public async Task<byte[]> GetImageBytesAsync()
+        {
+            using (var memoryStream = new MemoryStream())
+            {
+                await _imageData.SaveAsPngAsync(memoryStream);
+                return memoryStream.ToArray();
+            }
+        }
+
+
         /// <summary>
         /// Gets the image as stream.
         /// </summary>
@@ -156,6 +170,40 @@ public Stream GetImageStream()
         }
 
 
+        /// <summary>
+        /// Gets the image as stream.
+        /// </summary>
+        /// <returns></returns>
+        public async Task<Stream> GetImageStreamAsync()
+        {
+            var memoryStream = new MemoryStream();
+            await _imageData.SaveAsPngAsync(memoryStream);
+            return memoryStream;
+        }
+
+
+        /// <summary>
+        /// Copies the image to stream.
+        /// </summary>
+        /// <param name="destination">The destination.</param>
+        /// <returns></returns>
+        public void CopyToStream(Stream destination)
+        {
+            _imageData.SaveAsPng(destination);
+        }
+
+
+        /// <summary>
+        /// Copies the image to stream.
+        /// </summary>
+        /// <param name="destination">The destination.</param>
+        /// <returns></returns>
+        public Task CopyToStreamAsync(Stream destination)
+        {
+            return _imageData.SaveAsPngAsync(destination);
+        }
+
+
         /// <summary>
         /// Gets the image as tensor.
         /// </summary>
diff --git a/OnnxStack.Core/Video/OnnxVideo.cs b/OnnxStack.Core/Video/OnnxVideo.cs
index 0dd4c6f3..8569dd61 100644
--- a/OnnxStack.Core/Video/OnnxVideo.cs
+++ b/OnnxStack.Core/Video/OnnxVideo.cs
@@ -88,7 +88,7 @@ public OnnxVideo(VideoInfo info, IEnumerable<DenseTensor<float>> videoTensors)
         /// <summary>
         /// Gets the aspect ratio.
         /// </summary>
-        public double AspectRatio => (double)_info.Width / _info.Height;
+        public double AspectRatio => _info.AspectRatio;
 
         /// <summary>
         /// Gets a value indicating whether this instance has video.
diff --git a/OnnxStack.Core/Video/VideoHelper.cs b/OnnxStack.Core/Video/VideoHelper.cs
index 0b754733..a574837c 100644
--- a/OnnxStack.Core/Video/VideoHelper.cs
+++ b/OnnxStack.Core/Video/VideoHelper.cs
@@ -83,6 +83,36 @@ private static async Task WriteVideoFramesAsync(IEnumerable<OnnxImage> onnxImage
         }
 
 
+        /// <summary>
+        /// Writes the video stream to file.
+        /// </summary>
+        /// <param name="videoInfo">The video information.</param>
+        /// <param name="videoStream">The onnx image stream.</param>
+        /// <param name="filename">The filename.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        public static async Task WriteVideoStreamAsync(VideoInfo videoInfo, IAsyncEnumerable<OnnxImage> videoStream, string filename, CancellationToken cancellationToken = default)
+        {
+            if (File.Exists(filename))
+                File.Delete(filename);
+
+            using (var videoWriter = CreateWriter(filename, videoInfo.FrameRate, videoInfo.AspectRatio))
+            {
+                // Start FFMPEG
+                videoWriter.Start();
+                await foreach (var frame in videoStream)
+                {
+                    // Write each frame to the input stream of FFMPEG
+                    await frame.CopyToStreamAsync(videoWriter.StandardInput.BaseStream);
+                }
+
+                // Done, close the stream and wait for FFMPEG to finish processing
+                videoWriter.StandardInput.BaseStream.Close();
+                await videoWriter.WaitForExitAsync(cancellationToken);
+            }
+        }
+
+
         /// <summary>
         /// Reads the video information.
         /// </summary>
@@ -119,9 +149,16 @@ public static async Task<VideoInfo> ReadVideoInfoAsync(string filename)
         /// <returns></returns>
         public static async Task<List<OnnxImage>> ReadVideoFramesAsync(byte[] videoBytes, float frameRate = 15, CancellationToken cancellationToken = default)
         {
-            return await CreateFramesInternalAsync(videoBytes, frameRate, cancellationToken)
-                .Select(x => new OnnxImage(x))
-                .ToListAsync(cancellationToken);
+            string tempVideoPath = GetTempFilename();
+            try
+            {
+                await File.WriteAllBytesAsync(tempVideoPath, videoBytes, cancellationToken);
+                return await ReadVideoStreamAsync(tempVideoPath, frameRate, cancellationToken).ToListAsync(cancellationToken);
+            }
+            finally
+            {
+                DeleteTempFile(tempVideoPath);
+            }
         }
 
 
@@ -134,10 +171,23 @@ public static async Task<List<OnnxImage>> ReadVideoFramesAsync(byte[] videoBytes
         /// <returns></returns>
         public static async Task<List<OnnxImage>> ReadVideoFramesAsync(string filename, float frameRate = 15, CancellationToken cancellationToken = default)
         {
-            var videoBytes = await File.ReadAllBytesAsync(filename, cancellationToken);
-            return await CreateFramesInternalAsync(videoBytes, frameRate, cancellationToken)
-                .Select(x => new OnnxImage(x))
-                .ToListAsync(cancellationToken);
+            return await ReadVideoStreamAsync(filename, frameRate, cancellationToken).ToListAsync(cancellationToken);
+        }
+
+
+        /// <summary>
+        /// Reads the video frames as a stream.
+        /// </summary>
+        /// <param name="filename">The filename.</param>
+        /// <param name="frameRate">The frame rate.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns></returns>
+        public static async IAsyncEnumerable<OnnxImage> ReadVideoStreamAsync(string filename, float frameRate = 15, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
+            await foreach (var frameBytes in CreateFramesInternalAsync(filename, frameRate, cancellationToken))
+            {
+                yield return new OnnxImage(frameBytes);
+            }
         }
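ReadVideoStreamAsync yields each frame as soon as ffmpeg has decoded it, so callers can begin work before the whole file has been read and can stop early via the cancellation token. A minimal consumption sketch, assuming `videoFile`, `videoInfo` and `_outputDirectory` as in the examples above:

    var index = 0;
    await foreach (var frame in VideoHelper.ReadVideoStreamAsync(videoFile, videoInfo.FrameRate))
    {
        // each frame arrives as an OnnxImage and can be saved, inspected or fed to a pipeline
        await frame.SaveAsync(Path.Combine(_outputDirectory, $"Frame_{index++}.png"));
    }
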
@@ -152,76 +202,67 @@ public static async Task<List<OnnxImage>> ReadVideoFramesAsync(string filename,
         /// <param name="cancellationToken">The cancellation token.</param>
         /// <returns></returns>
         /// <exception cref="System.Exception">Invalid PNG header</exception>
-        private static async IAsyncEnumerable<byte[]> CreateFramesInternalAsync(byte[] videoData, float fps = 15, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        private static async IAsyncEnumerable<byte[]> CreateFramesInternalAsync(string fileName, float fps = 15, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
-            string tempVideoPath = GetTempFilename();
-            try
+            using (var ffmpegProcess = CreateReader(fileName, fps))
             {
-                await File.WriteAllBytesAsync(tempVideoPath, videoData, cancellationToken);
-                using (var ffmpegProcess = CreateReader(tempVideoPath, fps))
-                {
-                    // Start FFMPEG
-                    ffmpegProcess.Start();
+                // Start FFMPEG
+                ffmpegProcess.Start();
 
-                    // FFMPEG output stream
-                    var processOutputStream = ffmpegProcess.StandardOutput.BaseStream;
+                // FFMPEG output stream
+                var processOutputStream = ffmpegProcess.StandardOutput.BaseStream;
 
-                    // Buffer to hold the current image
-                    var buffer = new byte[20480000];
+                // Buffer to hold the current image
+                var buffer = new byte[20480000];
 
-                    var currentIndex = 0;
-                    while (!cancellationToken.IsCancellationRequested)
-                    {
-                        // Reset the index for the next PNG
-                        currentIndex = 0;
+                var currentIndex = 0;
+                while (!cancellationToken.IsCancellationRequested)
+                {
+                    // Reset the index for the next PNG
+                    currentIndex = 0;
 
-                        // Read the PNG Header
-                        if (await processOutputStream.ReadAsync(buffer.AsMemory(currentIndex, 8), cancellationToken) <= 0)
-                            break;
+                    // Read the PNG Header
+                    if (await processOutputStream.ReadAsync(buffer.AsMemory(currentIndex, 8), cancellationToken) <= 0)
+                        break;
 
-                        currentIndex += 8; // header length
+                    currentIndex += 8; // header length
 
-                        if (!IsImageHeader(buffer))
-                            throw new Exception("Invalid PNG header");
+                    if (!IsImageHeader(buffer))
+                        throw new Exception("Invalid PNG header");
 
-                        // loop through each chunk
-                        while (true)
-                        {
-                            // Read the chunk header
-                            await processOutputStream.ReadAsync(buffer.AsMemory(currentIndex, 12), cancellationToken);
+                    // loop through each chunk
+                    while (true)
+                    {
+                        // Read the chunk header
+                        await processOutputStream.ReadAsync(buffer.AsMemory(currentIndex, 12), cancellationToken);
 
-                            var chunkIndex = currentIndex;
-                            currentIndex += 12; // Chunk header length
+                        var chunkIndex = currentIndex;
+                        currentIndex += 12; // Chunk header length
 
-                            // Get the chunk's content size in bytes from the header we just read
-                            var totalSize = buffer[chunkIndex] << 24 | buffer[chunkIndex + 1] << 16 | buffer[chunkIndex + 2] << 8 | buffer[chunkIndex + 3];
-                            if (totalSize > 0)
+                        // Get the chunk's content size in bytes from the header we just read
+                        var totalSize = buffer[chunkIndex] << 24 | buffer[chunkIndex + 1] << 16 | buffer[chunkIndex + 2] << 8 | buffer[chunkIndex + 3];
+                        if (totalSize > 0)
+                        {
+                            var totalRead = 0;
+                            while (totalRead < totalSize)
                             {
-                                var totalRead = 0;
-                                while (totalRead < totalSize)
-                                {
-                                    int read = await processOutputStream.ReadAsync(buffer.AsMemory(currentIndex, totalSize - totalRead), cancellationToken);
-                                    currentIndex += read;
-                                    totalRead += read;
-                                }
-                                continue;
+                                int read = await processOutputStream.ReadAsync(buffer.AsMemory(currentIndex, totalSize - totalRead), cancellationToken);
+                                currentIndex += read;
+                                totalRead += read;
                             }
-
-                            // If the size is 0 and is the end of the image
-                            if (totalSize == 0 && IsImageEnd(buffer, chunkIndex))
-                                break;
+                            continue;
                         }
 
-                        yield return buffer[..currentIndex];
+                        // If the size is 0 and is the end of the image
+                        if (totalSize == 0 && IsImageEnd(buffer, chunkIndex))
+                            break;
                     }
 
-                    if (cancellationToken.IsCancellationRequested)
-                        ffmpegProcess.Kill();
+                    yield return buffer[..currentIndex];
                 }
-            }
-            finally
-            {
-                DeleteTempFile(tempVideoPath);
+
+                if (cancellationToken.IsCancellationRequested)
+                    ffmpegProcess.Kill();
             }
         }
diff --git a/OnnxStack.Core/Video/VideoInfo.cs b/OnnxStack.Core/Video/VideoInfo.cs
index 1a67870f..7b168558 100644
--- a/OnnxStack.Core/Video/VideoInfo.cs
+++ b/OnnxStack.Core/Video/VideoInfo.cs
@@ -11,5 +11,7 @@ public VideoInfo(int height, int width, TimeSpan duration, float frameRate) : th
         }
         public int Height { get; set; }
         public int Width { get; set; }
+
+        public double AspectRatio => (double)Width / Height;
     }
 }
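The frame reader above splits ffmpeg's PNG output stream by relying on the PNG container format: an 8-byte signature, then a sequence of chunks whose first four bytes hold the payload length in big-endian order, ending with a zero-length IEND chunk. A worked example of the length decode used in the loop (values are illustrative):

    // a chunk header beginning 0x00 0x00 0x10 0x00 means a 4096-byte payload follows
    var header = new byte[] { 0x00, 0x00, 0x10, 0x00 };
    int totalSize = header[0] << 24 | header[1] << 16 | header[2] << 8 | header[3];   // 4096
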
diff --git a/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs b/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs
index 7c2f99e4..4bf6724c 100644
--- a/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs
+++ b/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs
@@ -8,6 +8,7 @@
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
+using System.Runtime.CompilerServices;
 using System.Threading;
 using System.Threading.Tasks;
 
@@ -58,29 +59,9 @@ public async Task UnloadAsync()
         public async Task<OnnxImage> RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
         {
             var timestamp = _logger?.LogBegin("Extracting image feature...");
-            var controlImage = await inputImage.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, ImageNormalizeType.ZeroToOne);
-            var metadata = await _featureExtractorModel.GetMetadataAsync();
-            cancellationToken.ThrowIfCancellationRequested();
-            using (var inferenceParameters = new OnnxInferenceParameters(metadata))
-            {
-                inferenceParameters.AddInputTensor(controlImage);
-                inferenceParameters.AddOutputBuffer(new[] { 1, _featureExtractorModel.Channels, _featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize });
-
-                var results = await _featureExtractorModel.RunInferenceAsync(inferenceParameters);
-                using (var result = results.First())
-                {
-                    cancellationToken.ThrowIfCancellationRequested();
-
-                    var resultTensor = result.ToDenseTensor();
-                    if (_featureExtractorModel.Normalize)
-                        resultTensor.NormalizeMinMax();
-
-                    var maskImage = resultTensor.ToImageMask();
-                    //await maskImage.SaveAsPngAsync("D:\\Mask.png");
-                    _logger?.LogEnd("Extracting image feature complete.", timestamp);
-                    return maskImage;
-                }
-            }
+            var result = await RunInternalAsync(inputImage, cancellationToken);
+            _logger?.LogEnd("Extracting image feature complete.", timestamp);
+            return result;
         }
 
 
@@ -92,34 +73,61 @@ public async Task<OnnxImage> RunAsync(OnnxImage inputImage, CancellationToken ca
         public async Task<OnnxVideo> RunAsync(OnnxVideo video, CancellationToken cancellationToken = default)
         {
             var timestamp = _logger?.LogBegin("Extracting video features...");
+            var featureFrames = new List<OnnxImage>();
+            foreach (var videoFrame in video.Frames)
+            {
+                featureFrames.Add(await RunAsync(videoFrame, cancellationToken));
+            }
+            _logger?.LogEnd("Extracting video features complete.", timestamp);
+            return new OnnxVideo(video.Info, featureFrames);
+        }
+
+
+        /// <summary>
+        /// Generates the feature extractor video stream
+        /// </summary>
+        /// <param name="imageFrames">The image frames.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns></returns>
+        public async IAsyncEnumerable<OnnxImage> RunAsync(IAsyncEnumerable<OnnxImage> imageFrames, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
+            var timestamp = _logger?.LogBegin("Extracting video stream features...");
+            await foreach (var imageFrame in imageFrames)
+            {
+                yield return await RunInternalAsync(imageFrame, cancellationToken);
+            }
+            _logger?.LogEnd("Extracting video stream features complete.", timestamp);
+        }
+
+
+        /// <summary>
+        /// Runs the pipeline
+        /// </summary>
+        /// <param name="inputImage">The input image.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns></returns>
+        private async Task<OnnxImage> RunInternalAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
+        {
+            var controlImage = await inputImage.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, ImageNormalizeType.ZeroToOne);
             var metadata = await _featureExtractorModel.GetMetadataAsync();
             cancellationToken.ThrowIfCancellationRequested();
-
-            var frames = new List<OnnxImage>();
-            foreach (var videoFrame in video.Frames)
+            using (var inferenceParameters = new OnnxInferenceParameters(metadata))
             {
-                var controlImage = await videoFrame.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, ImageNormalizeType.ZeroToOne);
-                using (var inferenceParameters = new OnnxInferenceParameters(metadata))
-                {
-                    inferenceParameters.AddInputTensor(controlImage);
-                    inferenceParameters.AddOutputBuffer(new[] { 1, _featureExtractorModel.Channels, _featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize });
+                inferenceParameters.AddInputTensor(controlImage);
+                inferenceParameters.AddOutputBuffer(new[] { 1, _featureExtractorModel.Channels, _featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize });
 
-                    var results = await _featureExtractorModel.RunInferenceAsync(inferenceParameters);
-                    using (var result = results.First())
-                    {
-                        cancellationToken.ThrowIfCancellationRequested();
+                var results = await _featureExtractorModel.RunInferenceAsync(inferenceParameters);
+                using (var result = results.First())
+                {
+                    cancellationToken.ThrowIfCancellationRequested();
 
-                        var resultTensor = result.ToDenseTensor();
-                        if (_featureExtractorModel.Normalize)
-                            resultTensor.NormalizeMinMax();
+                    var resultTensor = result.ToDenseTensor();
+                    if (_featureExtractorModel.Normalize)
+                        resultTensor.NormalizeMinMax();
 
-                        var maskImage = resultTensor.ToImageMask();
-                        frames.Add(maskImage);
-                    }
+                    return resultTensor.ToImageMask();
                 }
             }
-            _logger?.LogEnd("Extracting video features complete.", timestamp);
-            return new OnnxVideo(video.Info, frames);
         }
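With the refactor above, the single-image, buffered-video and streaming overloads all funnel into RunInternalAsync and share one inference path. A minimal single-image sketch, assuming `frame` is an OnnxImage and the annotator path is a placeholder:

    var extractor = FeatureExtractorPipeline.CreatePipeline("annotators\\canny.onnx");   // placeholder model path
    OnnxImage featureMap = await extractor.RunAsync(frame);   // same code path the video overloads reuse per frame
    await extractor.UnloadAsync();
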
diff --git a/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs b/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs
index 964d259f..1761a916 100644
--- a/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs
+++ b/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs
@@ -4,13 +4,16 @@
 using OnnxStack.Core.Config;
 using OnnxStack.Core.Image;
 using OnnxStack.Core.Model;
+using OnnxStack.Core.Video;
 using OnnxStack.ImageUpscaler.Common;
 using OnnxStack.ImageUpscaler.Extensions;
 using OnnxStack.ImageUpscaler.Models;
 using SixLabors.ImageSharp;
 using SixLabors.ImageSharp.PixelFormats;
+using System.Collections.Generic;
 using System.IO;
 using System.Linq;
+using System.Runtime.CompilerServices;
 using System.Threading;
 using System.Threading.Tasks;
 
@@ -69,7 +72,60 @@ public async Task UnloadAsync()
         /// <param name="inputImage">The input image.</param>
         /// <param name="cancellationToken">The cancellation token.</param>
         /// <returns></returns>
-        public async Task<DenseTensor<float>> RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
+        public async Task<OnnxImage> RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
+        {
+            var timestamp = _logger?.LogBegin("Upscale image..");
+            var result = await RunInternalAsync(inputImage, cancellationToken);
+            _logger?.LogEnd("Upscale image complete.", timestamp);
+            return result;
+        }
+
+
+        /// <summary>
+        /// Runs the pipeline on a buffered video.
+        /// </summary>
+        /// <param name="inputVideo">The input video.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns></returns>
+        public async Task<OnnxVideo> RunAsync(OnnxVideo inputVideo, CancellationToken cancellationToken = default)
+        {
+            var timestamp = _logger?.LogBegin("Upscale video..");
+            var upscaledFrames = new List<OnnxImage>();
+            foreach (var videoFrame in inputVideo.Frames)
+            {
+                upscaledFrames.Add(await RunInternalAsync(videoFrame, cancellationToken));
+            }
+
+            var firstFrame = upscaledFrames.First();
+            var videoInfo = inputVideo.Info with
+            {
+                Width = firstFrame.Width,
+                Height = firstFrame.Height,
+            };
+
+            _logger?.LogEnd("Upscale video complete.", timestamp);
+            return new OnnxVideo(videoInfo, upscaledFrames);
+        }
+
+
+        /// <summary>
+        /// Runs the pipeline on a video stream.
+        /// </summary>
+        /// <param name="imageFrames">The image frames.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns></returns>
+        public async IAsyncEnumerable<OnnxImage> RunAsync(IAsyncEnumerable<OnnxImage> imageFrames, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
+            var timestamp = _logger?.LogBegin("Upscale video stream..");
+            await foreach (var imageFrame in imageFrames)
+            {
+                yield return await RunInternalAsync(imageFrame, cancellationToken);
+            }
+            _logger?.LogEnd("Upscale video stream complete.", timestamp);
+        }
+
+
+        private async Task<OnnxImage> RunInternalAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
         {
             var upscaleInput = CreateInputParams(inputImage, _upscaleModel.SampleSize, _upscaleModel.ScaleFactor);
             var metadata = await _upscaleModel.GetMetadataAsync();
@@ -93,10 +149,16 @@ public async Task<DenseTensor<float>> RunAsync(OnnxImage inputImage, Cancellatio
                     }
                 }
             }
-            return outputTensor;
+            return new OnnxImage(outputTensor, ImageNormalizeType.ZeroToOne);
         }
 
-
+        /// <summary>
+        /// Creates the input parameters.
+        /// </summary>
+        /// <param name="imageSource">The image source.</param>
+        /// <param name="maxTileSize">Maximum size of the tile.</param>
+        /// <param name="scaleFactor">The scale factor.</param>
+        /// <returns></returns>
         private static UpscaleInput CreateInputParams(OnnxImage imageSource, int maxTileSize, int scaleFactor)
         {
             var tiles = imageSource.GenerateTiles(maxTileSize, scaleFactor);
diff --git a/OnnxStack.StableDiffusion/Config/PromptOptions.cs b/OnnxStack.StableDiffusion/Config/PromptOptions.cs
index 2544b67d..46d38455 100644
--- a/OnnxStack.StableDiffusion/Config/PromptOptions.cs
+++ b/OnnxStack.StableDiffusion/Config/PromptOptions.cs
@@ -5,7 +5,7 @@
 
 namespace OnnxStack.StableDiffusion.Config
 {
-    public class PromptOptions
+    public record PromptOptions
     {
         public DiffuserType DiffuserType { get; set; }
 
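Changing PromptOptions from a class to a record is what lets the new streaming diffusion path clone the options per frame with a `with` expression instead of mutating a shared instance (see GenerateVideoStreamAsync below). A small sketch, assuming `frame` is an OnnxImage:

    var basePrompt = new PromptOptions { Prompt = "Iron Man", DiffuserType = DiffuserType.ImageToImage };

    // non-destructive per-frame copy; basePrompt itself is never modified
    var frameOptions = basePrompt with { InputImage = frame };
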
diff --git a/OnnxStack.StableDiffusion/Helpers/TensorHelper.cs b/OnnxStack.StableDiffusion/Helpers/TensorHelper.cs
index 942dd36e..35b3152d 100644
--- a/OnnxStack.StableDiffusion/Helpers/TensorHelper.cs
+++ b/OnnxStack.StableDiffusion/Helpers/TensorHelper.cs
@@ -247,15 +247,6 @@ public static DenseTensor<float> Clip(this DenseTensor<float> tensor, float minV
         }
 
 
-
-
-
-
-
-
-
-
-
         /// <summary>
         /// Generate a random Tensor from a normal distribution with mean 0 and variance 1
         /// </summary>
@@ -279,58 +270,5 @@ public static DenseTensor<float> GetRandomTensor(Random random, ReadOnlySpan<in
         }
 
 
-        /// <summary>
-        /// Splits the Tensor along axis 0.
-        /// </summary>
-        /// <param name="tensor">The tensor.</param>
-        /// <param name="count">The count.</param>
-        /// <param name="axis">The axis.</param>
-        /// <returns></returns>
-        /// <exception cref="System.NotImplementedException">Only axis 0 is supported</exception>
-        public static DenseTensor<float>[] Split(this DenseTensor<float> tensor, int count, int axis = 0)
-        {
-            if (axis != 0)
-                throw new NotImplementedException("Only axis 0 is supported");
-
-            var dimensions = tensor.Dimensions.ToArray();
-            dimensions[0] /= count;
-
-            var newLength = (int)tensor.Length / count;
-            var results = new DenseTensor<float>[count];
-            for (int i = 0; i < count; i++)
-            {
-                var start = i * newLength;
-                results[i] = new DenseTensor<float>(tensor.Buffer.Slice(start, newLength), dimensions);
-            }
-            return results;
-        }
-
-
-        /// <summary>
-        /// Joins the tensors across the 0 axis.
-        /// </summary>
-        /// <param name="tensors">The tensors.</param>
-        /// <param name="axis">The axis.</param>
-        /// <returns></returns>
-        /// <exception cref="System.NotImplementedException">Only axis 0 is supported</exception>
-        public static DenseTensor<float> Join(this IList<DenseTensor<float>> tensors, int axis = 0)
-        {
-            if (axis != 0)
-                throw new NotImplementedException("Only axis 0 is supported");
-
-            var tensor = tensors.First();
-            var dimensions = tensor.Dimensions.ToArray();
-            dimensions[0] *= tensors.Count;
-
-            var newLength = (int)tensor.Length;
-            var buffer = new float[newLength * tensors.Count].AsMemory();
-            for (int i = 0; i < tensors.Count(); i++)
-            {
-                var start = i * newLength;
-                tensors[i].Buffer.CopyTo(buffer[start..]);
-            }
-            return new DenseTensor<float>(buffer, dimensions);
-        }
     }
 }
diff --git a/OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs
index 57727451..30dad861 100644
--- a/OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs
+++ b/OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs
@@ -14,6 +14,13 @@ namespace OnnxStack.StableDiffusion.Pipelines
 {
     public interface IPipeline
     {
+
+        /// <summary>
+        /// Gets the name.
+        /// </summary>
+        string Name { get; }
+
+
         /// <summary>
         /// Gets the pipelines supported diffusers.
         /// </summary>
@@ -127,5 +134,17 @@ public interface IPipeline
         /// <param name="cancellationToken">The cancellation token.</param>
         /// <returns></returns>
         IAsyncEnumerable GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
+
+        /// <summary>
+        /// Runs the video stream pipeline returning each frame as an OnnxImage.
+        /// </summary>
+        /// <param name="videoFrames">The video frames.</param>
+        /// <param name="promptOptions">The prompt options.</param>
+        /// <param name="schedulerOptions">The scheduler options.</param>
+        /// <param name="controlNet">The control net.</param>
+        /// <param name="progressCallback">The progress callback.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns></returns>
+        IAsyncEnumerable<OnnxImage> GenerateVideoStreamAsync(IAsyncEnumerable<OnnxImage> videoFrames, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
     }
 }
\ No newline at end of file
diff --git a/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs b/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs
index 53b4b9e9..86a902d7 100644
--- a/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs
+++ b/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs
@@ -161,6 +161,19 @@ protected PipelineBase(PipelineOptions pipelineOptions, ILogger logger)
         public abstract IAsyncEnumerable GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
 
 
+        /// <summary>
+        /// Runs the video stream pipeline returning each frame as an OnnxImage.
+        /// </summary>
+        /// <param name="videoFrames">The video frames.</param>
+        /// <param name="promptOptions">The prompt options.</param>
+        /// <param name="schedulerOptions">The scheduler options.</param>
+        /// <param name="controlNet">The control net.</param>
+        /// <param name="progressCallback">The progress callback.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns></returns>
+        public abstract IAsyncEnumerable<OnnxImage> GenerateVideoStreamAsync(IAsyncEnumerable<OnnxImage> videoFrames, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
+
+
         /// <summary>
         /// Creates the diffuser.
         /// </summary>
diff --git a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs
index ac9745c5..eda5a720 100644
--- a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs
+++ b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs
@@ -266,6 +266,49 @@ public override async IAsyncEnumerable GenerateVideoBatchAsync
         }
 
 
+        /// <summary>
+        /// Runs the video stream pipeline returning each frame as an OnnxImage.
+        /// </summary>
+        /// <param name="videoFrames">The video frames.</param>
+        /// <param name="promptOptions">The prompt options.</param>
+        /// <param name="schedulerOptions">The scheduler options.</param>
+        /// <param name="controlNet">The control net.</param>
+        /// <param name="progressCallback">The progress callback.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns></returns>
+        public override async IAsyncEnumerable<OnnxImage> GenerateVideoStreamAsync(IAsyncEnumerable<OnnxImage> videoFrames, PromptOptions promptOptions, SchedulerOptions schedulerOptions = null, ControlNetModel controlNet = null, Action<DiffusionProgress> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
+            var diffuseTime = _logger?.LogBegin("Diffuser starting...");
+            var options = GetSchedulerOptionsOrDefault(schedulerOptions);
+            _logger?.Log($"Model: {Name}, Pipeline: {PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {options.SchedulerType}");
+
+            // Check guidance
+            var performGuidance = ShouldPerformGuidance(options);
+
+            // Process prompts
+            var promptEmbeddings = await CreatePromptEmbedsAsync(promptOptions, performGuidance);
+
+            // Create Diffuser
+            var diffuser = CreateDiffuser(promptOptions.DiffuserType, controlNet);
+
+            // Diffuse
+            await foreach (var videoFrame in videoFrames)
+            {
+                var frameOptions = promptOptions with
+                {
+                    InputImage = promptOptions.DiffuserType == DiffuserType.ImageToImage || promptOptions.DiffuserType == DiffuserType.ControlNetImage
+                        ? videoFrame : default,
+                    InputContolImage = promptOptions.DiffuserType == DiffuserType.ControlNet || promptOptions.DiffuserType == DiffuserType.ControlNetImage
+                        ? videoFrame : default,
+                };
+
+                yield return new OnnxImage(await DiffuseImageAsync(diffuser, frameOptions, options, promptEmbeddings, performGuidance, progressCallback, cancellationToken));
+            }
+
+            _logger?.LogEnd($"Diffuser complete", diffuseTime);
+        }
+
+
         /// <summary>
         /// Runs the pipeline
         /// </summary>
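In the loop above the prompt embeddings are created once and reused for every frame; only InputImage / InputContolImage are swapped per frame according to the DiffuserType. A hedged sketch of driving the stream with a ControlNet diffuser, assuming `pipeline`, `controlNet`, `videoFile`, `videoInfo` and `_outputDirectory` are already set up:

    var promptOptions = new PromptOptions { Prompt = "Iron Man", DiffuserType = DiffuserType.ControlNet };

    var frames = VideoHelper.ReadVideoStreamAsync(videoFile, videoInfo.FrameRate);
    var output = pipeline.GenerateVideoStreamAsync(frames, promptOptions, controlNet: controlNet);

    await VideoHelper.WriteVideoStreamAsync(videoInfo, output, Path.Combine(_outputDirectory, "ControlNet.mp4"));
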
diff --git a/OnnxStack.UI/Services/UpscaleService.cs b/OnnxStack.UI/Services/UpscaleService.cs
index 36823225..06c1007d 100644
--- a/OnnxStack.UI/Services/UpscaleService.cs
+++ b/OnnxStack.UI/Services/UpscaleService.cs
@@ -1,7 +1,7 @@
 using Microsoft.Extensions.Logging;
-using Microsoft.ML.OnnxRuntime.Tensors;
 using OnnxStack.Core.Config;
 using OnnxStack.Core.Image;
+using OnnxStack.Core.Video;
 using OnnxStack.FeatureExtractor.Pipelines;
 using OnnxStack.ImageUpscaler.Common;
 using System;
@@ -13,7 +13,7 @@ namespace OnnxStack.UI.Services
 {
     public class UpscaleService : IUpscaleService
     {
-        private readonly ILogger<UpscaleService> _logger;
+        private readonly ILogger _logger;
         private readonly Dictionary<UpscaleModelSet, ImageUpscalePipeline> _pipelines;
 
         /// <summary>
@@ -22,8 +22,9 @@ public class UpscaleService : IUpscaleService
         /// <param name="configuration">The configuration.</param>
        /// <param name="modelService">The model service.</param>
         /// <param name="imageService">The image service.</param>
-        public UpscaleService()
+        public UpscaleService(ILogger logger)
         {
+            _logger = logger;
             _pipelines = new Dictionary<UpscaleModelSet, ImageUpscalePipeline>();
         }
 
@@ -76,33 +77,43 @@ public bool IsModelLoaded(UpscaleModelSet modelOptions)
         /// <summary>
         /// Generates the upscaled image.
         /// </summary>
-        /// <param name="modelOptions">The model options.</param>
+        /// <param name="modelSet">The model options.</param>
         /// <param name="inputImage">The input image.</param>
         /// <returns></returns>
-        public async Task<OnnxImage> GenerateAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default)
+        public async Task<OnnxImage> GenerateAsync(UpscaleModelSet modelSet, OnnxImage inputImage, CancellationToken cancellationToken = default)
         {
-            return new OnnxImage(await GenerateInternalAsync(modelOptions, inputImage, cancellationToken), ImageNormalizeType.ZeroToOne);
+            if (!_pipelines.TryGetValue(modelSet, out var pipeline))
+                throw new Exception("Pipeline not found or is unsupported");
+
+            return await pipeline.RunAsync(inputImage, cancellationToken);
         }
 
 
         /// <summary>
-        /// Generates an upscaled image of the source provided.
+        /// Generates the upscaled video.
         /// </summary>
-        /// <param name="modelSet">The model options.</param>
-        /// <param name="inputImage">The input image.</param>
-        private async Task<DenseTensor<float>> GenerateInternalAsync(UpscaleModelSet modelSet, OnnxImage inputImage, CancellationToken cancellationToken)
+        /// <param name="modelSet">The model set.</param>
+        /// <param name="inputVideo">The input video.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns></returns>
+        /// <exception cref="System.Exception">Pipeline not found or is unsupported</exception>
+        public async Task<OnnxVideo> GenerateAsync(UpscaleModelSet modelSet, OnnxVideo inputVideo, CancellationToken cancellationToken = default)
         {
             if (!_pipelines.TryGetValue(modelSet, out var pipeline))
                 throw new Exception("Pipeline not found or is unsupported");
 
-            return await pipeline.RunAsync(inputImage, cancellationToken);
+            return await pipeline.RunAsync(inputVideo, cancellationToken);
         }
 
 
+        /// <summary>
+        /// Creates the pipeline.
+        /// </summary>
+        /// <param name="modelSet">The model set.</param>
+        /// <returns></returns>
         private ImageUpscalePipeline CreatePipeline(UpscaleModelSet modelSet)
         {
             return ImageUpscalePipeline.CreatePipeline(modelSet, _logger);
         }
-
     }
 }
\ No newline at end of file
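A hedged sketch of consuming the updated service, with construction shown directly rather than through the UI's dependency injection, and with `modelSet`, `inputImage` and `inputVideo` assumed to be prepared elsewhere:

    var service = new UpscaleService(logger);   // logger: any ILogger instance
    // both overloads resolve the cached ImageUpscalePipeline for the model set;
    // if the model set was never loaded, GenerateAsync throws "Pipeline not found or is unsupported"
    OnnxImage upscaledImage = await service.GenerateAsync(modelSet, inputImage);
    OnnxVideo upscaledVideo = await service.GenerateAsync(modelSet, inputVideo);
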