diff --git a/OnnxStack.Console/Examples/FeatureExtractorVideoExample.cs b/OnnxStack.Console/Examples/FeatureExtractorVideoExample.cs
new file mode 100644
index 00000000..467bfaab
--- /dev/null
+++ b/OnnxStack.Console/Examples/FeatureExtractorVideoExample.cs
@@ -0,0 +1,50 @@
+using OnnxStack.Core.Video;
+using OnnxStack.FeatureExtractor.Pipelines;
+using OnnxStack.StableDiffusion.Config;
+
+namespace OnnxStack.Console.Runner
+{
+ public sealed class FeatureExtractorVideoExample : IExampleRunner
+ {
+ private readonly string _outputDirectory;
+ private readonly StableDiffusionConfig _configuration;
+
+ public FeatureExtractorVideoExample(StableDiffusionConfig configuration)
+ {
+ _configuration = configuration;
+ _outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", nameof(FeatureExtractorVideoExample));
+ Directory.CreateDirectory(_outputDirectory);
+ }
+
+ public int Index => 13;
+
+ public string Name => "Feature Extractor Video Example";
+
+ public string Description => "Video example using basic feature extractor";
+
+ /// <summary>
+ /// Feature Extractor Video Example
+ /// </summary>
+ public async Task RunAsync()
+ {
+ // Read Video
+ var videoFile = "C:\\Users\\Deven\\Pictures\\parrot.mp4";
+ var videoInfo = await VideoHelper.ReadVideoInfoAsync(videoFile);
+
+ // Create pipeline
+ var pipeline = FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\canny.onnx");
+
+ // Create Video Stream
+ var videoStream = VideoHelper.ReadVideoStreamAsync(videoFile, videoInfo.FrameRate);
+
+ // Create Pipeline Stream
+ var pipelineStream = pipeline.RunAsync(videoStream);
+
+ // Write Video Stream
+ await VideoHelper.WriteVideoStreamAsync(videoInfo, pipelineStream, Path.Combine(_outputDirectory, $"Result.mp4"));
+
+ //Unload
+ await pipeline.UnloadAsync();
+ }
+ }
+}
diff --git a/OnnxStack.Console/Examples/UpscaleExample.cs b/OnnxStack.Console/Examples/UpscaleExample.cs
index 15360de1..d031b1a5 100644
--- a/OnnxStack.Console/Examples/UpscaleExample.cs
+++ b/OnnxStack.Console/Examples/UpscaleExample.cs
@@ -29,13 +29,10 @@ public async Task RunAsync()
// Run pipeline
var result = await pipeline.RunAsync(inputImage);
-
- // Create Image from Tensor result
- var image = new OnnxImage(result, ImageNormalizeType.ZeroToOne);
-
+
// Save Image File
var outputFilename = Path.Combine(_outputDirectory, $"Upscaled.png");
- await image.SaveAsync(outputFilename);
+ await result.SaveAsync(outputFilename);
// Unload
await pipeline.UnloadAsync();
diff --git a/OnnxStack.Console/Examples/UpscaleStreamExample.cs b/OnnxStack.Console/Examples/UpscaleStreamExample.cs
new file mode 100644
index 00000000..0c52839b
--- /dev/null
+++ b/OnnxStack.Console/Examples/UpscaleStreamExample.cs
@@ -0,0 +1,48 @@
+using OnnxStack.Core.Video;
+using OnnxStack.FeatureExtractor.Pipelines;
+
+namespace OnnxStack.Console.Runner
+{
+ public sealed class UpscaleStreamExample : IExampleRunner
+ {
+ private readonly string _outputDirectory;
+
+ public UpscaleStreamExample()
+ {
+ _outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", nameof(UpscaleStreamExample));
+ Directory.CreateDirectory(_outputDirectory);
+ }
+
+ public int Index => 10;
+
+ public string Name => "Streaming Video Upscale Demo";
+
+ public string Description => "Upscales a video stream";
+
+ public async Task RunAsync()
+ {
+ // Read Video
+ var videoFile = "C:\\Users\\Deven\\Pictures\\parrot.mp4";
+ var videoInfo = await VideoHelper.ReadVideoInfoAsync(videoFile);
+
+ // Create pipeline
+ var pipeline = ImageUpscalePipeline.CreatePipeline("D:\\Repositories\\upscaler\\SwinIR\\003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.onnx", 4);
+
+ // Load pipeline
+ await pipeline.LoadAsync();
+
+ // Create Video Stream
+ var videoStream = VideoHelper.ReadVideoStreamAsync(videoFile, videoInfo.FrameRate);
+
+ // Create Pipeline Stream
+ var pipelineStream = pipeline.RunAsync(videoStream);
+
+ // Write Video Stream
+ await VideoHelper.WriteVideoStreamAsync(videoInfo, pipelineStream, Path.Combine(_outputDirectory, $"Result.mp4"));
+
+ //Unload
+ await pipeline.UnloadAsync();
+ }
+
+ }
+}
diff --git a/OnnxStack.Console/Examples/VideoToVideoStreamExample.cs b/OnnxStack.Console/Examples/VideoToVideoStreamExample.cs
new file mode 100644
index 00000000..961a196e
--- /dev/null
+++ b/OnnxStack.Console/Examples/VideoToVideoStreamExample.cs
@@ -0,0 +1,67 @@
+using OnnxStack.Core.Video;
+using OnnxStack.FeatureExtractor.Pipelines;
+using OnnxStack.StableDiffusion.Config;
+using OnnxStack.StableDiffusion.Enums;
+using OnnxStack.StableDiffusion.Pipelines;
+
+namespace OnnxStack.Console.Runner
+{
+ public sealed class VideoToVideoStreamExample : IExampleRunner
+ {
+ private readonly string _outputDirectory;
+ private readonly StableDiffusionConfig _configuration;
+
+ public VideoToVideoStreamExample(StableDiffusionConfig configuration)
+ {
+ _configuration = configuration;
+ _outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", nameof(VideoToVideoStreamExample));
+ Directory.CreateDirectory(_outputDirectory);
+ }
+
+ public int Index => 4;
+
+ public string Name => "Video To Video Stream Demo";
+
+ public string Description => "Video Stream Stable Diffusion Inference";
+
+ public async Task RunAsync()
+ {
+
+ // Read Video
+ var videoFile = "C:\\Users\\Deven\\Pictures\\gidsgphy.gif";
+ var videoInfo = await VideoHelper.ReadVideoInfoAsync(videoFile);
+
+ // Loop through the appsettings.json model sets
+ foreach (var modelSet in _configuration.ModelSets)
+ {
+ OutputHelpers.WriteConsole($"Loading Model `{modelSet.Name}`...", ConsoleColor.Cyan);
+
+ // Create Pipeline
+ var pipeline = PipelineBase.CreatePipeline(modelSet);
+
+ // Preload Models (optional)
+ await pipeline.LoadAsync();
+
+ // Add text to prompt (video frames are supplied by the stream)
+ var promptOptions = new PromptOptions
+ {
+ Prompt = "Iron Man",
+ DiffuserType = DiffuserType.ImageToImage
+ };
+
+
+ // Create Video Stream
+ var videoStream = VideoHelper.ReadVideoStreamAsync(videoFile, videoInfo.FrameRate);
+
+ // Create Pipeline Stream
+ var pipelineStream = pipeline.GenerateVideoStreamAsync(videoStream, promptOptions, progressCallback:OutputHelpers.ProgressCallback);
+
+ // Write Video Stream
+ await VideoHelper.WriteVideoStreamAsync(videoInfo, pipelineStream, Path.Combine(_outputDirectory, $"{modelSet.PipelineType}.mp4"));
+
+ //Unload
+ await pipeline.UnloadAsync();
+ }
+ }
+ }
+}
diff --git a/OnnxStack.Core/Extensions/TensorExtension.cs b/OnnxStack.Core/Extensions/TensorExtension.cs
index a6c6456d..c6dbce45 100644
--- a/OnnxStack.Core/Extensions/TensorExtension.cs
+++ b/OnnxStack.Core/Extensions/TensorExtension.cs
@@ -1,6 +1,7 @@
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Collections.Generic;
+using System.Linq;
namespace OnnxStack.Core
{
@@ -75,6 +76,33 @@ public static IEnumerable<DenseTensor<float>> SplitBatch(this DenseTensor
}
+ /// <summary>
+ /// Joins the tensors across the 0 axis.
+ /// </summary>
+ /// <param name="tensors">The tensors.</param>
+ /// <param name="axis">The axis.</param>
+ /// <returns></returns>
+ /// <exception cref="System.NotImplementedException">Only axis 0 is supported</exception>
+ public static DenseTensor<float> Join(this IList<DenseTensor<float>> tensors, int axis = 0)
+ {
+ if (axis != 0)
+ throw new NotImplementedException("Only axis 0 is supported");
+
+ var tensor = tensors.First();
+ var dimensions = tensor.Dimensions.ToArray();
+ dimensions[0] *= tensors.Count;
+
+ var newLength = (int)tensor.Length;
+ var buffer = new float[newLength * tensors.Count].AsMemory();
+ for (int i = 0; i < tensors.Count(); i++)
+ {
+ var start = i * newLength;
+ tensors[i].Buffer.CopyTo(buffer[start..]);
+ }
+ return new DenseTensor<float>(buffer, dimensions);
+ }
+
+
/// <summary>
/// Concatenates the specified tensors along the specified axis.
/// </summary>
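For reference, a minimal usage sketch of the new Join extension (values and shapes are illustrative; assumes OnnxStack.Core and Microsoft.ML.OnnxRuntime.Tensors are in scope):

    // Two 1x4 tensors joined along axis 0 yield a single 2x4 tensor.
    var a = new DenseTensor<float>(new float[] { 1, 2, 3, 4 }, new[] { 1, 4 });
    var b = new DenseTensor<float>(new float[] { 5, 6, 7, 8 }, new[] { 1, 4 });
    var joined = new List<DenseTensor<float>> { a, b }.Join(); // Dimensions: [2, 4]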
diff --git a/OnnxStack.Core/Image/OnnxImage.cs b/OnnxStack.Core/Image/OnnxImage.cs
index d4d837bb..d0998e8a 100644
--- a/OnnxStack.Core/Image/OnnxImage.cs
+++ b/OnnxStack.Core/Image/OnnxImage.cs
@@ -144,6 +144,20 @@ public byte[] GetImageBytes()
}
+ /// <summary>
+ /// Gets the image as bytes.
+ /// </summary>
+ /// <returns></returns>
+ public async Task<byte[]> GetImageBytesAsync()
+ {
+ using (var memoryStream = new MemoryStream())
+ {
+ await _imageData.SaveAsPngAsync(memoryStream);
+ return memoryStream.ToArray();
+ }
+ }
+
+
/// <summary>
/// Gets the image as stream.
/// </summary>
@@ -156,6 +170,40 @@ public Stream GetImageStream()
}
+ /// <summary>
+ /// Gets the image as stream.
+ /// </summary>
+ /// <returns></returns>
+ public async Task<Stream> GetImageStreamAsync()
+ {
+ var memoryStream = new MemoryStream();
+ await _imageData.SaveAsPngAsync(memoryStream);
+ return memoryStream;
+ }
+
+
+ /// <summary>
+ /// Copies the image to stream.
+ /// </summary>
+ /// <param name="destination">The destination.</param>
+ /// <returns></returns>
+ public void CopyToStream(Stream destination)
+ {
+ _imageData.SaveAsPng(destination);
+ }
+
+
+ /// <summary>
+ /// Copies the image to stream.
+ /// </summary>
+ /// <param name="destination">The destination.</param>
+ /// <returns></returns>
+ public Task CopyToStreamAsync(Stream destination)
+ {
+ return _imageData.SaveAsPngAsync(destination);
+ }
+
+
/// <summary>
/// Gets the image as tensor.
/// </summary>
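A minimal sketch of how the new async accessors might be used (file paths are placeholders):

    var image = new OnnxImage(await File.ReadAllBytesAsync("frame.png"));
    await using (var file = File.Create("frame-copy.png"))
    {
        // Encodes the image as PNG directly into the destination stream,
        // which is how WriteVideoStreamAsync feeds frames to FFMPEG below.
        await image.CopyToStreamAsync(file);
    }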
diff --git a/OnnxStack.Core/Video/OnnxVideo.cs b/OnnxStack.Core/Video/OnnxVideo.cs
index 0dd4c6f3..8569dd61 100644
--- a/OnnxStack.Core/Video/OnnxVideo.cs
+++ b/OnnxStack.Core/Video/OnnxVideo.cs
@@ -88,7 +88,7 @@ public OnnxVideo(VideoInfo info, IEnumerable> videoTensors)
/// <summary>
/// Gets the aspect ratio.
/// </summary>
- public double AspectRatio => (double)_info.Width / _info.Height;
+ public double AspectRatio => _info.AspectRatio;
/// <summary>
/// Gets a value indicating whether this instance has video.
diff --git a/OnnxStack.Core/Video/VideoHelper.cs b/OnnxStack.Core/Video/VideoHelper.cs
index 0b754733..a574837c 100644
--- a/OnnxStack.Core/Video/VideoHelper.cs
+++ b/OnnxStack.Core/Video/VideoHelper.cs
@@ -83,6 +83,36 @@ private static async Task WriteVideoFramesAsync(IEnumerable<OnnxImage> onnxImage
}
+ /// <summary>
+ /// Writes the video stream to file.
+ /// </summary>
+ /// <param name="videoInfo">The video information.</param>
+ /// <param name="videoStream">The onnx image stream.</param>
+ /// <param name="filename">The filename.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ /// <returns></returns>
+ public static async Task WriteVideoStreamAsync(VideoInfo videoInfo, IAsyncEnumerable<OnnxImage> videoStream, string filename, CancellationToken cancellationToken = default)
+ {
+ if (File.Exists(filename))
+ File.Delete(filename);
+
+ using (var videoWriter = CreateWriter(filename, videoInfo.FrameRate, videoInfo.AspectRatio))
+ {
+ // Start FFMPEG
+ videoWriter.Start();
+ await foreach (var frame in videoStream)
+ {
+ // Write each frame to the input stream of FFMPEG
+ await frame.CopyToStreamAsync(videoWriter.StandardInput.BaseStream);
+ }
+
+ // Done close stream and wait for app to process
+ videoWriter.StandardInput.BaseStream.Close();
+ await videoWriter.WaitForExitAsync(cancellationToken);
+ }
+ }
+
+
/// <summary>
/// Reads the video information.
/// </summary>
@@ -119,9 +149,16 @@ public static async Task<VideoInfo> ReadVideoInfoAsync(string filename)
///
public static async Task<List<OnnxImage>> ReadVideoFramesAsync(byte[] videoBytes, float frameRate = 15, CancellationToken cancellationToken = default)
{
- return await CreateFramesInternalAsync(videoBytes, frameRate, cancellationToken)
- .Select(x => new OnnxImage(x))
- .ToListAsync(cancellationToken);
+ string tempVideoPath = GetTempFilename();
+ try
+ {
+ await File.WriteAllBytesAsync(tempVideoPath, videoBytes, cancellationToken);
+ return await ReadVideoStreamAsync(tempVideoPath, frameRate, cancellationToken).ToListAsync(cancellationToken);
+ }
+ finally
+ {
+ DeleteTempFile(tempVideoPath);
+ }
}
@@ -134,10 +171,23 @@ public static async Task<List<OnnxImage>> ReadVideoFramesAsync(byte[] videoBytes
///
public static async Task<List<OnnxImage>> ReadVideoFramesAsync(string filename, float frameRate = 15, CancellationToken cancellationToken = default)
{
- var videoBytes = await File.ReadAllBytesAsync(filename, cancellationToken);
- return await CreateFramesInternalAsync(videoBytes, frameRate, cancellationToken)
- .Select(x => new OnnxImage(x))
- .ToListAsync(cancellationToken);
+ return await ReadVideoStreamAsync(filename, frameRate, cancellationToken).ToListAsync(cancellationToken);
+ }
+
+
+ /// <summary>
+ /// Reads the video frames as a stream.
+ /// </summary>
+ /// <param name="filename">The filename.</param>
+ /// <param name="frameRate">The frame rate.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ /// <returns></returns>
+ public static async IAsyncEnumerable<OnnxImage> ReadVideoStreamAsync(string filename, float frameRate = 15, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ await foreach (var frameBytes in CreateFramesInternalAsync(filename, frameRate, cancellationToken))
+ {
+ yield return new OnnxImage(frameBytes);
+ }
}
@@ -152,76 +202,67 @@ public static async Task<List<OnnxImage>> ReadVideoFramesAsync(string filename,
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
/// <exception cref="System.Exception">Invalid PNG header</exception>
- private static async IAsyncEnumerable<byte[]> CreateFramesInternalAsync(byte[] videoData, float fps = 15, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ private static async IAsyncEnumerable<byte[]> CreateFramesInternalAsync(string fileName, float fps = 15, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
- string tempVideoPath = GetTempFilename();
- try
+ using (var ffmpegProcess = CreateReader(fileName, fps))
{
- await File.WriteAllBytesAsync(tempVideoPath, videoData, cancellationToken);
- using (var ffmpegProcess = CreateReader(tempVideoPath, fps))
- {
- // Start FFMPEG
- ffmpegProcess.Start();
+ // Start FFMPEG
+ ffmpegProcess.Start();
- // FFMPEG output stream
- var processOutputStream = ffmpegProcess.StandardOutput.BaseStream;
+ // FFMPEG output stream
+ var processOutputStream = ffmpegProcess.StandardOutput.BaseStream;
- // Buffer to hold the current image
- var buffer = new byte[20480000];
+ // Buffer to hold the current image
+ var buffer = new byte[20480000];
- var currentIndex = 0;
- while (!cancellationToken.IsCancellationRequested)
- {
- // Reset the index new PNG
- currentIndex = 0;
+ var currentIndex = 0;
+ while (!cancellationToken.IsCancellationRequested)
+ {
+ // Reset the index new PNG
+ currentIndex = 0;
- // Read the PNG Header
- if (await processOutputStream.ReadAsync(buffer.AsMemory(currentIndex, 8), cancellationToken) <= 0)
- break;
+ // Read the PNG Header
+ if (await processOutputStream.ReadAsync(buffer.AsMemory(currentIndex, 8), cancellationToken) <= 0)
+ break;
- currentIndex += 8;// header length
+ currentIndex += 8;// header length
- if (!IsImageHeader(buffer))
- throw new Exception("Invalid PNG header");
+ if (!IsImageHeader(buffer))
+ throw new Exception("Invalid PNG header");
- // loop through each chunk
- while (true)
- {
- // Read the chunk header
- await processOutputStream.ReadAsync(buffer.AsMemory(currentIndex, 12), cancellationToken);
+ // loop through each chunk
+ while (true)
+ {
+ // Read the chunk header
+ await processOutputStream.ReadAsync(buffer.AsMemory(currentIndex, 12), cancellationToken);
- var chunkIndex = currentIndex;
- currentIndex += 12; // Chunk header length
+ var chunkIndex = currentIndex;
+ currentIndex += 12; // Chunk header length
- // Get the chunk's content size in bytes from the header we just read
- var totalSize = buffer[chunkIndex] << 24 | buffer[chunkIndex + 1] << 16 | buffer[chunkIndex + 2] << 8 | buffer[chunkIndex + 3];
- if (totalSize > 0)
+ // Get the chunk's content size in bytes from the header we just read
+ var totalSize = buffer[chunkIndex] << 24 | buffer[chunkIndex + 1] << 16 | buffer[chunkIndex + 2] << 8 | buffer[chunkIndex + 3];
+ if (totalSize > 0)
+ {
+ var totalRead = 0;
+ while (totalRead < totalSize)
{
- var totalRead = 0;
- while (totalRead < totalSize)
- {
- int read = await processOutputStream.ReadAsync(buffer.AsMemory(currentIndex, totalSize - totalRead), cancellationToken);
- currentIndex += read;
- totalRead += read;
- }
- continue;
+ int read = await processOutputStream.ReadAsync(buffer.AsMemory(currentIndex, totalSize - totalRead), cancellationToken);
+ currentIndex += read;
+ totalRead += read;
}
-
- // If the size is 0 and is the end of the image
- if (totalSize == 0 && IsImageEnd(buffer, chunkIndex))
- break;
+ continue;
}
- yield return buffer[..currentIndex];
+ // If the size is 0 and is the end of the image
+ if (totalSize == 0 && IsImageEnd(buffer, chunkIndex))
+ break;
}
- if (cancellationToken.IsCancellationRequested)
- ffmpegProcess.Kill();
+ yield return buffer[..currentIndex];
}
- }
- finally
- {
- DeleteTempFile(tempVideoPath);
+
+ if (cancellationToken.IsCancellationRequested)
+ ffmpegProcess.Kill();
}
}
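CreateFramesInternalAsync splits FFMPEG's concatenated PNG output on chunk boundaries: each chunk is a 4-byte big-endian length, a 4-byte type, the data, and a 4-byte CRC, and a frame ends at the zero-length IEND chunk. The IsImageHeader/IsImageEnd helpers it calls are outside this diff; a rough sketch of what equivalent checks might look like:

    // A PNG starts with the fixed 8-byte signature 89 50 4E 47 0D 0A 1A 0A.
    private static bool IsImageHeader(byte[] buffer)
    {
        ReadOnlySpan<byte> signature = stackalloc byte[] { 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A };
        return buffer.AsSpan(0, 8).SequenceEqual(signature);
    }

    // The zero-length "IEND" chunk marks the end of one PNG image; the chunk
    // type begins 4 bytes after the length field read at chunkIndex.
    private static bool IsImageEnd(byte[] buffer, int chunkIndex)
    {
        return buffer[chunkIndex + 4] == 'I' && buffer[chunkIndex + 5] == 'E'
            && buffer[chunkIndex + 6] == 'N' && buffer[chunkIndex + 7] == 'D';
    }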
diff --git a/OnnxStack.Core/Video/VideoInfo.cs b/OnnxStack.Core/Video/VideoInfo.cs
index 1a67870f..7b168558 100644
--- a/OnnxStack.Core/Video/VideoInfo.cs
+++ b/OnnxStack.Core/Video/VideoInfo.cs
@@ -11,5 +11,7 @@ public VideoInfo(int height, int width, TimeSpan duration, float frameRate) : th
}
public int Height { get; set; }
public int Width { get; set; }
+
+ public double AspectRatio => (double)Width / Height;
}
}
diff --git a/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs b/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs
index 7c2f99e4..4bf6724c 100644
--- a/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs
+++ b/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs
@@ -8,6 +8,7 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
+using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
@@ -58,29 +59,9 @@ public async Task UnloadAsync()
public async Task<OnnxImage> RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
{
var timestamp = _logger?.LogBegin("Extracting image feature...");
- var controlImage = await inputImage.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, ImageNormalizeType.ZeroToOne);
- var metadata = await _featureExtractorModel.GetMetadataAsync();
- cancellationToken.ThrowIfCancellationRequested();
- using (var inferenceParameters = new OnnxInferenceParameters(metadata))
- {
- inferenceParameters.AddInputTensor(controlImage);
- inferenceParameters.AddOutputBuffer(new[] { 1, _featureExtractorModel.Channels, _featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize });
-
- var results = await _featureExtractorModel.RunInferenceAsync(inferenceParameters);
- using (var result = results.First())
- {
- cancellationToken.ThrowIfCancellationRequested();
-
- var resultTensor = result.ToDenseTensor();
- if (_featureExtractorModel.Normalize)
- resultTensor.NormalizeMinMax();
-
- var maskImage = resultTensor.ToImageMask();
- //await maskImage.SaveAsPngAsync("D:\\Mask.png");
- _logger?.LogEnd("Extracting image feature complete.", timestamp);
- return maskImage;
- }
- }
+ var result = await RunInternalAsync(inputImage, cancellationToken);
+ _logger?.LogEnd("Extracting image feature complete.", timestamp);
+ return result;
}
@@ -92,34 +73,61 @@ public async Task<OnnxImage> RunAsync(OnnxImage inputImage, CancellationToken ca
public async Task<OnnxVideo> RunAsync(OnnxVideo video, CancellationToken cancellationToken = default)
{
var timestamp = _logger?.LogBegin("Extracting video features...");
+ var featureFrames = new List<OnnxImage>();
+ foreach (var videoFrame in video.Frames)
+ {
+ featureFrames.Add(await RunAsync(videoFrame, cancellationToken));
+ }
+ _logger?.LogEnd("Extracting video features complete.", timestamp);
+ return new OnnxVideo(video.Info, featureFrames);
+ }
+
+
+ /// <summary>
+ /// Generates the feature extractor video stream
+ /// </summary>
+ /// <param name="imageFrames">The image frames.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ /// <returns></returns>
+ public async IAsyncEnumerable<OnnxImage> RunAsync(IAsyncEnumerable<OnnxImage> imageFrames, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ var timestamp = _logger?.LogBegin("Extracting video stream features...");
+ await foreach (var imageFrame in imageFrames)
+ {
+ yield return await RunInternalAsync(imageFrame, cancellationToken);
+ }
+ _logger?.LogEnd("Extracting video stream features complete.", timestamp);
+ }
+
+
+ /// <summary>
+ /// Runs the pipeline
+ /// </summary>
+ /// <param name="inputImage">The input image.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ /// <returns></returns>
+ private async Task<OnnxImage> RunInternalAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
+ {
+ var controlImage = await inputImage.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, ImageNormalizeType.ZeroToOne);
var metadata = await _featureExtractorModel.GetMetadataAsync();
cancellationToken.ThrowIfCancellationRequested();
-
- var frames = new List<OnnxImage>();
- foreach (var videoFrame in video.Frames)
+ using (var inferenceParameters = new OnnxInferenceParameters(metadata))
{
- var controlImage = await videoFrame.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, ImageNormalizeType.ZeroToOne);
- using (var inferenceParameters = new OnnxInferenceParameters(metadata))
- {
- inferenceParameters.AddInputTensor(controlImage);
- inferenceParameters.AddOutputBuffer(new[] { 1, _featureExtractorModel.Channels, _featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize });
+ inferenceParameters.AddInputTensor(controlImage);
+ inferenceParameters.AddOutputBuffer(new[] { 1, _featureExtractorModel.Channels, _featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize });
- var results = await _featureExtractorModel.RunInferenceAsync(inferenceParameters);
- using (var result = results.First())
- {
- cancellationToken.ThrowIfCancellationRequested();
+ var results = await _featureExtractorModel.RunInferenceAsync(inferenceParameters);
+ using (var result = results.First())
+ {
+ cancellationToken.ThrowIfCancellationRequested();
- var resultTensor = result.ToDenseTensor();
- if (_featureExtractorModel.Normalize)
- resultTensor.NormalizeMinMax();
+ var resultTensor = result.ToDenseTensor();
+ if (_featureExtractorModel.Normalize)
+ resultTensor.NormalizeMinMax();
- var maskImage = resultTensor.ToImageMask();
- frames.Add(maskImage);
- }
+ return resultTensor.ToImageMask();
}
}
- _logger?.LogEnd("Extracting video features complete.", timestamp);
- return new OnnxVideo(video.Info, frames);
}
diff --git a/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs b/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs
index 964d259f..1761a916 100644
--- a/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs
+++ b/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs
@@ -4,13 +4,16 @@
using OnnxStack.Core.Config;
using OnnxStack.Core.Image;
using OnnxStack.Core.Model;
+using OnnxStack.Core.Video;
using OnnxStack.ImageUpscaler.Common;
using OnnxStack.ImageUpscaler.Extensions;
using OnnxStack.ImageUpscaler.Models;
using SixLabors.ImageSharp;
using SixLabors.ImageSharp.PixelFormats;
+using System.Collections.Generic;
using System.IO;
using System.Linq;
+using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
@@ -69,7 +72,60 @@ public async Task UnloadAsync()
/// <param name="inputImage">The input image.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
- public async Task<DenseTensor<float>> RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
+ public async Task<OnnxImage> RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
+ {
+ var timestamp = _logger?.LogBegin("Upscale image..");
+ var result = await RunInternalAsync(inputImage, cancellationToken);
+ _logger?.LogEnd("Upscale image complete.", timestamp);
+ return result;
+ }
+
+
+ /// <summary>
+ /// Runs the pipeline on a buffered video.
+ /// </summary>
+ /// <param name="inputVideo">The input video.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ /// <returns></returns>
+ public async Task<OnnxVideo> RunAsync(OnnxVideo inputVideo, CancellationToken cancellationToken = default)
+ {
+ var timestamp = _logger?.LogBegin("Upscale video..");
+ var upscaledFrames = new List<OnnxImage>();
+ foreach (var videoFrame in inputVideo.Frames)
+ {
+ upscaledFrames.Add(await RunInternalAsync(videoFrame, cancellationToken));
+ }
+
+ var firstFrame = upscaledFrames.First();
+ var videoInfo = inputVideo.Info with
+ {
+ Width = firstFrame.Width,
+ Height = firstFrame.Height,
+ };
+
+ _logger?.LogEnd("Upscale video complete.", timestamp);
+ return new OnnxVideo(videoInfo, upscaledFrames);
+ }
+
+
+ /// <summary>
+ /// Runs the pipeline on a video stream.
+ /// </summary>
+ /// <param name="imageFrames">The image frames.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ /// <returns></returns>
+ public async IAsyncEnumerable<OnnxImage> RunAsync(IAsyncEnumerable<OnnxImage> imageFrames, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ var timestamp = _logger?.LogBegin("Upscale video stream..");
+ await foreach (var imageFrame in imageFrames)
+ {
+ yield return await RunInternalAsync(imageFrame, cancellationToken);
+ }
+ _logger?.LogEnd("Upscale video stream complete.", timestamp);
+ }
+
+
+ private async Task<OnnxImage> RunInternalAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
{
var upscaleInput = CreateInputParams(inputImage, _upscaleModel.SampleSize, _upscaleModel.ScaleFactor);
var metadata = await _upscaleModel.GetMetadataAsync();
@@ -93,10 +149,16 @@ public async Task<DenseTensor<float>> RunAsync(OnnxImage inputImage, Cancellatio
}
}
}
- return outputTensor;
+ return new OnnxImage(outputTensor, ImageNormalizeType.ZeroToOne);
}
-
+ /// <summary>
+ /// Creates the input parameters.
+ /// </summary>
+ /// <param name="imageSource">The image source.</param>
+ /// <param name="maxTileSize">Maximum size of the tile.</param>
+ /// <param name="scaleFactor">The scale factor.</param>
+ /// <returns></returns>
private static UpscaleInput CreateInputParams(OnnxImage imageSource, int maxTileSize, int scaleFactor)
{
var tiles = imageSource.GenerateTiles(maxTileSize, scaleFactor);
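Because the feature extractor and upscaler now both consume and yield IAsyncEnumerable<OnnxImage>, their streaming overloads compose without buffering the whole video. A rough sketch (model and file paths are placeholders, not part of this change):

    var videoInfo = await VideoHelper.ReadVideoInfoAsync("input.mp4");
    var frames = VideoHelper.ReadVideoStreamAsync("input.mp4", videoInfo.FrameRate);

    // Lazily chain the pipelines: each frame is feature-extracted, then upscaled, then written.
    var extractor = FeatureExtractorPipeline.CreatePipeline("canny.onnx");
    var upscaler = ImageUpscalePipeline.CreatePipeline("upscaler_x4.onnx", 4);
    var processed = upscaler.RunAsync(extractor.RunAsync(frames));

    await VideoHelper.WriteVideoStreamAsync(videoInfo, processed, "output.mp4");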
diff --git a/OnnxStack.StableDiffusion/Config/PromptOptions.cs b/OnnxStack.StableDiffusion/Config/PromptOptions.cs
index 2544b67d..46d38455 100644
--- a/OnnxStack.StableDiffusion/Config/PromptOptions.cs
+++ b/OnnxStack.StableDiffusion/Config/PromptOptions.cs
@@ -5,7 +5,7 @@
namespace OnnxStack.StableDiffusion.Config
{
- public class PromptOptions
+ public record PromptOptions
{
public DiffuserType DiffuserType { get; set; }
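Changing PromptOptions from a class to a record is what enables the non-destructive per-frame copies used later in GenerateVideoStreamAsync, e.g.:

    // Clone the shared prompt options, swapping in only the current frame.
    var frameOptions = promptOptions with { InputImage = videoFrame };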
diff --git a/OnnxStack.StableDiffusion/Helpers/TensorHelper.cs b/OnnxStack.StableDiffusion/Helpers/TensorHelper.cs
index 942dd36e..35b3152d 100644
--- a/OnnxStack.StableDiffusion/Helpers/TensorHelper.cs
+++ b/OnnxStack.StableDiffusion/Helpers/TensorHelper.cs
@@ -247,15 +247,6 @@ public static DenseTensor<float> Clip(this DenseTensor<float> tensor, float minV
}
-
-
-
-
-
-
-
-
-
/// <summary>
/// Generate a random Tensor from a normal distribution with mean 0 and variance 1
/// </summary>
@@ -279,58 +270,5 @@ public static DenseTensor<float> GetRandomTensor(Random random, ReadOnlySpan
- /// Splits the Tensor along axis 0.
- ///
- /// The tensor.
- /// The count.
- /// The axis.
- ///
- /// Only axis 0 is supported
- public static DenseTensor<float>[] Split(this DenseTensor<float> tensor, int count, int axis = 0)
- {
- if (axis != 0)
- throw new NotImplementedException("Only axis 0 is supported");
-
- var dimensions = tensor.Dimensions.ToArray();
- dimensions[0] /= count;
-
- var newLength = (int)tensor.Length / count;
- var results = new DenseTensor[count];
- for (int i = 0; i < count; i++)
- {
- var start = i * newLength;
- results[i] = new DenseTensor<float>(tensor.Buffer.Slice(start, newLength), dimensions);
- }
- return results;
- }
-
-
- ///
- /// Joins the tensors across the 0 axis.
- ///
- /// The tensors.
- /// The axis.
- ///
- /// Only axis 0 is supported
- public static DenseTensor<float> Join(this IList<DenseTensor<float>> tensors, int axis = 0)
- {
- if (axis != 0)
- throw new NotImplementedException("Only axis 0 is supported");
-
- var tensor = tensors.First();
- var dimensions = tensor.Dimensions.ToArray();
- dimensions[0] *= tensors.Count;
-
- var newLength = (int)tensor.Length;
- var buffer = new float[newLength * tensors.Count].AsMemory();
- for (int i = 0; i < tensors.Count(); i++)
- {
- var start = i * newLength;
- tensors[i].Buffer.CopyTo(buffer[start..]);
- }
- return new DenseTensor<float>(buffer, dimensions);
- }
}
}
diff --git a/OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs
index 57727451..30dad861 100644
--- a/OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs
+++ b/OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs
@@ -14,6 +14,13 @@ namespace OnnxStack.StableDiffusion.Pipelines
{
public interface IPipeline
{
+
+ /// <summary>
+ /// Gets the name.
+ /// </summary>
+ string Name { get; }
+
+
/// <summary>
/// Gets the pipelines supported diffusers.
/// </summary>
@@ -127,5 +134,17 @@ public interface IPipeline
/// The cancellation token.
///
IAsyncEnumerable GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, CancellationToken cancellationToken = default);
+
+ /// <summary>
+ /// Runs the video stream pipeline returning each frame as an OnnxImage.
+ /// </summary>
+ /// <param name="videoFrames">The video frames.</param>
+ /// <param name="promptOptions">The prompt options.</param>
+ /// <param name="schedulerOptions">The scheduler options.</param>
+ /// <param name="controlNet">The control net.</param>
+ /// <param name="progressCallback">The progress callback.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ /// <returns></returns>
+ IAsyncEnumerable<OnnxImage> GenerateVideoStreamAsync(IAsyncEnumerable<OnnxImage> videoFrames, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, CancellationToken cancellationToken = default);
}
}
\ No newline at end of file
diff --git a/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs b/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs
index 53b4b9e9..86a902d7 100644
--- a/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs
+++ b/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs
@@ -161,6 +161,19 @@ protected PipelineBase(PipelineOptions pipelineOptions, ILogger logger)
public abstract IAsyncEnumerable GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, CancellationToken cancellationToken = default);
+ /// <summary>
+ /// Runs the video stream pipeline returning each frame as an OnnxImage.
+ /// </summary>
+ /// <param name="videoFrames">The video frames.</param>
+ /// <param name="promptOptions">The prompt options.</param>
+ /// <param name="schedulerOptions">The scheduler options.</param>
+ /// <param name="controlNet">The control net.</param>
+ /// <param name="progressCallback">The progress callback.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ /// <returns></returns>
+ public abstract IAsyncEnumerable<OnnxImage> GenerateVideoStreamAsync(IAsyncEnumerable<OnnxImage> videoFrames, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action progressCallback = null, CancellationToken cancellationToken = default);
+
+
/// <summary>
/// Creates the diffuser.
/// </summary>
diff --git a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs
index ac9745c5..eda5a720 100644
--- a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs
+++ b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs
@@ -266,6 +266,49 @@ public override async IAsyncEnumerable GenerateVideoBatchAsync
}
+ /// <summary>
+ /// Runs the video stream pipeline returning each frame as an OnnxImage.
+ /// </summary>
+ /// <param name="videoFrames">The video frames.</param>
+ /// <param name="promptOptions">The prompt options.</param>
+ /// <param name="schedulerOptions">The scheduler options.</param>
+ /// <param name="controlNet">The control net.</param>
+ /// <param name="progressCallback">The progress callback.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ /// <returns></returns>
+ public override async IAsyncEnumerable<OnnxImage> GenerateVideoStreamAsync(IAsyncEnumerable<OnnxImage> videoFrames, PromptOptions promptOptions, SchedulerOptions schedulerOptions = null, ControlNetModel controlNet = null, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ var diffuseTime = _logger?.LogBegin("Diffuser starting...");
+ var options = GetSchedulerOptionsOrDefault(schedulerOptions);
+ _logger?.Log($"Model: {Name}, Pipeline: {PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {options.SchedulerType}");
+
+ // Check guidance
+ var performGuidance = ShouldPerformGuidance(options);
+
+ // Process prompts
+ var promptEmbeddings = await CreatePromptEmbedsAsync(promptOptions, performGuidance);
+
+ // Create Diffuser
+ var diffuser = CreateDiffuser(promptOptions.DiffuserType, controlNet);
+
+ // Diffuse
+ await foreach (var videoFrame in videoFrames)
+ {
+ var frameOptions = promptOptions with
+ {
+ InputImage = promptOptions.DiffuserType == DiffuserType.ImageToImage || promptOptions.DiffuserType == DiffuserType.ControlNetImage
+ ? videoFrame : default,
+ InputContolImage = promptOptions.DiffuserType == DiffuserType.ControlNet || promptOptions.DiffuserType == DiffuserType.ControlNetImage
+ ? videoFrame : default,
+ };
+
+ yield return new OnnxImage(await DiffuseImageAsync(diffuser, frameOptions, options, promptEmbeddings, performGuidance, progressCallback, cancellationToken));
+ }
+
+ _logger?.LogEnd($"Diffuser complete", diffuseTime);
+ }
+
+
/// <summary>
/// Runs the pipeline
/// </summary>
diff --git a/OnnxStack.UI/Services/UpscaleService.cs b/OnnxStack.UI/Services/UpscaleService.cs
index 36823225..06c1007d 100644
--- a/OnnxStack.UI/Services/UpscaleService.cs
+++ b/OnnxStack.UI/Services/UpscaleService.cs
@@ -1,7 +1,7 @@
using Microsoft.Extensions.Logging;
-using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core.Config;
using OnnxStack.Core.Image;
+using OnnxStack.Core.Video;
using OnnxStack.FeatureExtractor.Pipelines;
using OnnxStack.ImageUpscaler.Common;
using System;
@@ -13,7 +13,7 @@ namespace OnnxStack.UI.Services
{
public class UpscaleService : IUpscaleService
{
- private readonly ILogger _logger;
+ private readonly ILogger _logger;
private readonly Dictionary<UpscaleModelSet, ImageUpscalePipeline> _pipelines;
///
@@ -22,8 +22,9 @@ public class UpscaleService : IUpscaleService
/// The configuration.
/// The model service.
/// The image service.
- public UpscaleService()
+ public UpscaleService(ILogger logger)
{
+ _logger = logger;
_pipelines = new Dictionary<UpscaleModelSet, ImageUpscalePipeline>();
}
@@ -76,33 +77,43 @@ public bool IsModelLoaded(UpscaleModelSet modelOptions)
/// <summary>
/// Generates the upscaled image.
/// </summary>
- /// <param name="modelOptions">The model options.</param>
+ /// <param name="modelSet">The model options.</param>
/// <param name="inputImage">The input image.</param>
/// <returns></returns>
- public async Task<OnnxImage> GenerateAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default)
+ public async Task<OnnxImage> GenerateAsync(UpscaleModelSet modelSet, OnnxImage inputImage, CancellationToken cancellationToken = default)
{
- return new OnnxImage(await GenerateInternalAsync(modelOptions, inputImage, cancellationToken), ImageNormalizeType.ZeroToOne);
+ if (!_pipelines.TryGetValue(modelSet, out var pipeline))
+ throw new Exception("Pipeline not found or is unsupported");
+
+ return await pipeline.RunAsync(inputImage, cancellationToken);
}
/// <summary>
- /// Generates an upscaled image of the source provided.
+ /// Generates the upscaled video.
/// </summary>
- /// <param name="modelSet">The model options.</param>
- /// <param name="inputImage">The input image.</param>
- private async Task<DenseTensor<float>> GenerateInternalAsync(UpscaleModelSet modelSet, OnnxImage inputImage, CancellationToken cancellationToken)
+ /// <param name="modelSet">The model set.</param>
+ /// <param name="inputVideo">The input video.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ /// <returns></returns>
+ /// <exception cref="System.Exception">Pipeline not found or is unsupported</exception>
+ public async Task<OnnxVideo> GenerateAsync(UpscaleModelSet modelSet, OnnxVideo inputVideo, CancellationToken cancellationToken = default)
{
if (!_pipelines.TryGetValue(modelSet, out var pipeline))
throw new Exception("Pipeline not found or is unsupported");
- return await pipeline.RunAsync(inputImage, cancellationToken);
+ return await pipeline.RunAsync(inputVideo, cancellationToken);
}
+ /// <summary>
+ /// Creates the pipeline.
+ /// </summary>
+ /// <param name="modelSet">The model set.</param>
+ /// <returns></returns>
private ImageUpscalePipeline CreatePipeline(UpscaleModelSet modelSet)
{
return ImageUpscalePipeline.CreatePipeline(modelSet, _logger);
}
-
}
}
\ No newline at end of file