diff --git a/OnnxStack.Console/Examples/BackgroundRemovalImageExample.cs b/OnnxStack.Console/Examples/BackgroundRemovalImageExample.cs
deleted file mode 100644
index 40551b8..0000000
--- a/OnnxStack.Console/Examples/BackgroundRemovalImageExample.cs
+++ /dev/null
@@ -1,51 +0,0 @@
-using OnnxStack.Core.Image;
-using OnnxStack.FeatureExtractor.Pipelines;
-using System.Diagnostics;
-
-namespace OnnxStack.Console.Runner
-{
-    public sealed class BackgroundRemovalImageExample : IExampleRunner
-    {
-        private readonly string _outputDirectory;
-
-        public BackgroundRemovalImageExample()
-        {
-            _outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", "BackgroundRemovalExample");
-            Directory.CreateDirectory(_outputDirectory);
-        }
-
-        public int Index => 20;
-
-        public string Name => "Image Background Removal Example";
-
-        public string Description => "Remove a background from an image";
-
-        /// <summary>
-        /// ControlNet Example
-        /// </summary>
-        public async Task RunAsync()
-        {
-            OutputHelpers.WriteConsole("Please enter an image file path and press ENTER", ConsoleColor.Yellow);
-            var imageFile = OutputHelpers.ReadConsole(ConsoleColor.Cyan);
-
-            var timestamp = Stopwatch.GetTimestamp();
-
-            OutputHelpers.WriteConsole($"Load Image", ConsoleColor.Gray);
-            var inputImage = await OnnxImage.FromFileAsync(imageFile);
-
-            OutputHelpers.WriteConsole($"Create Pipeline", ConsoleColor.Gray);
-            var pipeline = BackgroundRemovalPipeline.CreatePipeline("D:\\Repositories\\RMBG-1.4\\onnx\\model.onnx", sampleSize: 1024);
-
-            OutputHelpers.WriteConsole($"Run Pipeline", ConsoleColor.Gray);
-            var imageFeature = await pipeline.RunAsync(inputImage);
-
-            OutputHelpers.WriteConsole($"Save Image", ConsoleColor.Gray);
-            await imageFeature.SaveAsync(Path.Combine(_outputDirectory, $"{pipeline.Name}.png"));
-
-            OutputHelpers.WriteConsole($"Unload pipeline", ConsoleColor.Gray);
-            await pipeline.UnloadAsync();
-
-            OutputHelpers.WriteConsole($"Elapsed: {Stopwatch.GetElapsedTime(timestamp)}ms", ConsoleColor.Yellow);
-        }
-    }
-}
diff --git a/OnnxStack.Console/Examples/BackgroundRemovalVideoExample.cs b/OnnxStack.Console/Examples/BackgroundRemovalVideoExample.cs
deleted file mode 100644
index 788b8d1..0000000
--- a/OnnxStack.Console/Examples/BackgroundRemovalVideoExample.cs
+++ /dev/null
@@ -1,54 +0,0 @@
-using OnnxStack.Core.Video;
-using OnnxStack.FeatureExtractor.Pipelines;
-using System.Diagnostics;
-
-namespace OnnxStack.Console.Runner
-{
-    public sealed class BackgroundRemovalVideoExample : IExampleRunner
-    {
-        private readonly string _outputDirectory;
-
-        public BackgroundRemovalVideoExample()
-        {
-            _outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", "BackgroundRemovalExample");
-            Directory.CreateDirectory(_outputDirectory);
-        }
-
-        public int Index => 21;
-
-        public string Name => "Video Background Removal Example";
-
-        public string Description => "Remove a background from an video";
-
-        public async Task RunAsync()
-        {
-            OutputHelpers.WriteConsole("Please enter an video/gif file path and press ENTER", ConsoleColor.Yellow);
-            var videoFile = OutputHelpers.ReadConsole(ConsoleColor.Cyan);
-
-            var timestamp = Stopwatch.GetTimestamp();
-
-            OutputHelpers.WriteConsole($"Read Video", ConsoleColor.Gray);
-            var videoInfo = await VideoHelper.ReadVideoInfoAsync(videoFile);
-
-            OutputHelpers.WriteConsole($"Create Pipeline", ConsoleColor.Gray);
-            var pipeline = BackgroundRemovalPipeline.CreatePipeline("D:\\Repositories\\RMBG-1.4\\onnx\\model.onnx", sampleSize: 1024);
-
-            OutputHelpers.WriteConsole($"Load Pipeline", ConsoleColor.Gray);
-            await pipeline.LoadAsync();
-
-            OutputHelpers.WriteConsole($"Create Video Stream", ConsoleColor.Gray);
-            var videoStream = VideoHelper.ReadVideoStreamAsync(videoFile, videoInfo.FrameRate);
-
-            OutputHelpers.WriteConsole($"Create Pipeline Stream", ConsoleColor.Gray);
-            var pipelineStream = pipeline.RunAsync(videoStream);
-
-            OutputHelpers.WriteConsole($"Write Video Stream", ConsoleColor.Gray);
-            await VideoHelper.WriteVideoStreamAsync(videoInfo, pipelineStream, Path.Combine(_outputDirectory, $"Result.mp4"), true);
-
-            OutputHelpers.WriteConsole($"Unload", ConsoleColor.Gray);
-            await pipeline.UnloadAsync();
-
-            OutputHelpers.WriteConsole($"Elapsed: {Stopwatch.GetElapsedTime(timestamp)}ms", ConsoleColor.Yellow);
-        }
-    }
-}
diff --git a/OnnxStack.Console/Examples/ControlNetFeatureExample.cs b/OnnxStack.Console/Examples/ControlNetFeatureExample.cs
index a3bd0e4..2341f04 100644
--- a/OnnxStack.Console/Examples/ControlNetFeatureExample.cs
+++ b/OnnxStack.Console/Examples/ControlNetFeatureExample.cs
@@ -35,7 +35,7 @@ public async Task RunAsync()
             var inputImage = await OnnxImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\Img2Img_Start.bmp");

             // Create Annotation pipeline
-            var annotationPipeline = FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", true);
+            var annotationPipeline = FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", sampleSize: 512, normalizeOutputTensor: true);

             // Create Depth Image
             var controlImage = await annotationPipeline.RunAsync(inputImage);
diff --git a/OnnxStack.Console/Examples/FeatureExtractorExample.cs b/OnnxStack.Console/Examples/FeatureExtractorExample.cs
index c45e7b6..e7a1ee3 100644
--- a/OnnxStack.Console/Examples/FeatureExtractorExample.cs
+++ b/OnnxStack.Console/Examples/FeatureExtractorExample.cs
@@ -37,10 +37,8 @@ public async Task RunAsync()
             {
                 FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\canny.onnx"),
                 FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\hed.onnx"),
-                FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", true),
-
-                // FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\depth-anything-large-hf\\onnx\\model.onnx", normalize: true, sampleSize: 504),
-                // FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\sentis-MiDaS\\dpt_beit_large_512.onnx", normalize: true, sampleSize: 384),
+                FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", sampleSize: 512, normalizeOutputTensor: true, inputResizeMode: ImageResizeMode.Stretch),
+                FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\RMBG-1.4\\onnx\\model.onnx", sampleSize: 1024, setOutputToInputAlpha: true, inputResizeMode: ImageResizeMode.Stretch)
             };

             foreach (var pipeline in pipelines)
@@ -49,7 +47,7 @@ public async Task RunAsync()
             {
                 OutputHelpers.WriteConsole($"Load pipeline`{pipeline.Name}`", ConsoleColor.Cyan);

                 // Run Image Pipeline
-                var imageFeature = await pipeline.RunAsync(inputImage);
+                var imageFeature = await pipeline.RunAsync(inputImage.Clone());

                 OutputHelpers.WriteConsole($"Generating image", ConsoleColor.Cyan);
diff --git a/OnnxStack.Core/Extensions/Extensions.cs b/OnnxStack.Core/Extensions/Extensions.cs
index dafd961..3ea62c6 100644
--- a/OnnxStack.Core/Extensions/Extensions.cs
+++ b/OnnxStack.Core/Extensions/Extensions.cs
@@ -1,4 +1,5 @@
 using Microsoft.ML.OnnxRuntime;
+using Microsoft.ML.OnnxRuntime.Tensors;
 using OnnxStack.Core.Config;
 using System;
 using System.Collections.Concurrent;
@@ -244,5 +245,26 @@ public static long[] ToLong(this int[] array)
         {
             return Array.ConvertAll(array, Convert.ToInt64);
         }
+
+
+        /// <summary>
+        /// Normalize the data using Min-Max scaling to ensure all values are in the range [0, 1].
+        /// </summary>
+        /// <param name="values">The values.</param>
+        public static void NormalizeMinMax(this Span<float> values)
+        {
+            float min = float.PositiveInfinity, max = float.NegativeInfinity;
+            foreach (var val in values)
+            {
+                if (min > val) min = val;
+                if (max < val) max = val;
+            }
+
+            var range = max - min;
+            for (var i = 0; i < values.Length; i++)
+            {
+                values[i] = (values[i] - min) / range;
+            }
+        }
     }
 }
diff --git a/OnnxStack.Core/Extensions/TensorExtension.cs b/OnnxStack.Core/Extensions/TensorExtension.cs
index 6e93467..ffcb29b 100644
--- a/OnnxStack.Core/Extensions/TensorExtension.cs
+++ b/OnnxStack.Core/Extensions/TensorExtension.cs
@@ -286,19 +286,7 @@ public static DenseTensor<float> Repeat(this DenseTensor<float> tensor1, int cou
         /// <param name="tensor">The tensor.</param>
         public static void NormalizeMinMax(this DenseTensor<float> tensor)
         {
-            var values = tensor.Buffer.Span;
-            float min = float.PositiveInfinity, max = float.NegativeInfinity;
-            foreach (var val in values)
-            {
-                if (min > val) min = val;
-                if (max < val) max = val;
-            }
-
-            var range = max - min;
-            for (var i = 0; i < values.Length; i++)
-            {
-                values[i] = (values[i] - min) / range;
-            }
+            tensor.Buffer.Span.NormalizeMinMax();
         }


diff --git a/OnnxStack.Core/Image/Extensions.cs b/OnnxStack.Core/Image/Extensions.cs
index 8a74b4e..e75d86a 100644
--- a/OnnxStack.Core/Image/Extensions.cs
+++ b/OnnxStack.Core/Image/Extensions.cs
@@ -1,6 +1,7 @@
 using Microsoft.ML.OnnxRuntime.Tensors;
 using SixLabors.ImageSharp;
 using SixLabors.ImageSharp.PixelFormats;
+using SixLabors.ImageSharp.Processing;

 namespace OnnxStack.Core.Image
 {
@@ -29,11 +30,27 @@ public static OnnxImage ToImageMask(this DenseTensor<float> imageTensor)
             }
         }

+
+        public static ResizeMode ToResizeMode(this ImageResizeMode resizeMode)
+        {
+            return resizeMode switch
+            {
+                ImageResizeMode.Stretch => ResizeMode.Stretch,
+                _ => ResizeMode.Crop
+            };
+        }
+
     }

     public enum ImageNormalizeType
     {
         ZeroToOne = 0,
-        OneToOne = 1,
+        OneToOne = 1
+    }
+
+    public enum ImageResizeMode
+    {
+        Crop = 0,
+        Stretch = 1
     }
 }
diff --git a/OnnxStack.Core/Image/OnnxImage.cs b/OnnxStack.Core/Image/OnnxImage.cs
index 643e514..a9d9cb2 100644
--- a/OnnxStack.Core/Image/OnnxImage.cs
+++ b/OnnxStack.Core/Image/OnnxImage.cs
@@ -230,15 +230,27 @@ public DenseTensor<float> GetImageTensor(ImageNormalizeType normalizeType = Imag
         /// <param name="normalizeType">Type of the normalize.</param>
         /// <param name="channels">The channels.</param>
         /// <returns></returns>
-        public DenseTensor<float> GetImageTensor(int height, int width, ImageNormalizeType normalizeType = ImageNormalizeType.OneToOne, int channels = 3)
+        public DenseTensor<float> GetImageTensor(int height, int width, ImageNormalizeType normalizeType = ImageNormalizeType.OneToOne, int channels = 3, ImageResizeMode resizeMode = ImageResizeMode.Crop)
         {
             if (height > 0 && width > 0)
-                Resize(height, width);
+                Resize(height, width, resizeMode);

             return GetImageTensor(normalizeType, channels);
         }


+        /// <summary>
+        /// Gets the image as tensor asynchronously.
+        /// </summary>
+        /// <param name="normalizeType">Type of the normalize.</param>
+        /// <param name="channels">The channels.</param>
+        /// <returns></returns>
+        public Task<DenseTensor<float>> GetImageTensorAsync(ImageNormalizeType normalizeType = ImageNormalizeType.OneToOne, int channels = 3)
+        {
+            return Task.Run(() => GetImageTensor(normalizeType, channels));
+        }
+
+
         /// <summary>
         /// Gets the image as tensor asynchronously.
         /// </summary>
@@ -247,9 +259,9 @@ public DenseTensor<float> GetImageTensor(int height, int width, ImageNormalizeTy
         /// <param name="normalizeType">Type of the normalize.</param>
         /// <param name="channels">The channels.</param>
         /// <returns></returns>
-        public Task<DenseTensor<float>> GetImageTensorAsync(int height, int width, ImageNormalizeType normalizeType = ImageNormalizeType.OneToOne, int channels = 3)
+        public Task<DenseTensor<float>> GetImageTensorAsync(int height, int width, ImageNormalizeType normalizeType = ImageNormalizeType.OneToOne, int channels = 3, ImageResizeMode resizeMode = ImageResizeMode.Crop)
         {
-            return Task.Run(() => GetImageTensor(height, width, normalizeType, channels));
+            return Task.Run(() => GetImageTensor(height, width, normalizeType, channels, resizeMode));
         }


@@ -259,20 +271,21 @@ public Task<DenseTensor<float>> GetImageTensorAsync(int height, int width, Image
         /// <param name="height">The height.</param>
         /// <param name="width">The width.</param>
         /// <param name="resizeMode">The resize mode.</param>
-        public void Resize(int height, int width, ResizeMode resizeMode = ResizeMode.Crop)
+        public void Resize(int height, int width, ImageResizeMode resizeMode = ImageResizeMode.Crop)
         {
             _imageData.Mutate(x =>
             {
                 x.Resize(new ResizeOptions
                 {
                     Size = new Size(width, height),
-                    Mode = resizeMode,
+                    Mode = resizeMode.ToResizeMode(),
                     Sampler = KnownResamplers.Lanczos8,
                     Compand = true
                 });
             });
         }

+
         public OnnxImage Clone()
         {
             return new OnnxImage(_imageData);
diff --git a/OnnxStack.FeatureExtractor/Common/FeatureExtractorModel.cs b/OnnxStack.FeatureExtractor/Common/FeatureExtractorModel.cs
index 2760668..e63b159 100644
--- a/OnnxStack.FeatureExtractor/Common/FeatureExtractorModel.cs
+++ b/OnnxStack.FeatureExtractor/Common/FeatureExtractorModel.cs
@@ -1,56 +1,52 @@
 using Microsoft.ML.OnnxRuntime;
 using OnnxStack.Core.Config;
+using OnnxStack.Core.Image;
 using OnnxStack.Core.Model;

 namespace OnnxStack.FeatureExtractor.Common
 {
     public class FeatureExtractorModel : OnnxModelSession
     {
-        private readonly int _sampleSize;
-        private readonly bool _normalize;
-        private readonly int _channels;
+        private readonly FeatureExtractorModelConfig _configuration;

         public FeatureExtractorModel(FeatureExtractorModelConfig configuration)
             : base(configuration)
         {
-            _sampleSize = configuration.SampleSize;
-            _normalize = configuration.Normalize;
-            _channels = configuration.Channels;
+            _configuration = configuration;
         }

-        public int SampleSize => _sampleSize;
-
-        public bool Normalize => _normalize;
-
-        public int Channels => _channels;
+        public int OutputChannels => _configuration.OutputChannels;
+        public int SampleSize => _configuration.SampleSize;
+        public bool NormalizeOutputTensor => _configuration.NormalizeOutputTensor;
+        public bool SetOutputToInputAlpha => _configuration.SetOutputToInputAlpha;
+        public ImageResizeMode InputResizeMode => _configuration.InputResizeMode;
+        public ImageNormalizeType InputNormalization => _configuration.NormalizeInputTensor;

         public static FeatureExtractorModel Create(FeatureExtractorModelConfig configuration)
         {
             return new FeatureExtractorModel(configuration);
         }

-        public static FeatureExtractorModel Create(string modelFile, bool normalize = false, int sampleSize = 512, int channels = 3, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
+        public static FeatureExtractorModel Create(string modelFile, int sampleSize = 0, int outputChannels = 1, bool normalizeOutputTensor = false, ImageNormalizeType normalizeInputTensor = ImageNormalizeType.ZeroToOne, ImageResizeMode inputResizeMode = ImageResizeMode.Crop, bool setOutputToInputAlpha = false, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
         {
             var configuration = new FeatureExtractorModelConfig
             {
-                SampleSize = sampleSize,
-                Normalize = normalize,
-                Channels = channels,
                 DeviceId = deviceId,
                 ExecutionProvider = executionProvider,
                 ExecutionMode = ExecutionMode.ORT_SEQUENTIAL,
                 InterOpNumThreads = 0,
                 IntraOpNumThreads = 0,
-                OnnxModelPath = modelFile
+                OnnxModelPath = modelFile,
+
+
+                SampleSize = sampleSize,
+                OutputChannels = outputChannels,
+                NormalizeOutputTensor = normalizeOutputTensor,
+                SetOutputToInputAlpha = setOutputToInputAlpha,
+                NormalizeInputTensor = normalizeInputTensor,
+                InputResizeMode = inputResizeMode
             };
             return new FeatureExtractorModel(configuration);
         }
     }
-
-
-    public record FeatureExtractorModelConfig : OnnxModelConfig
-    {
-        public int SampleSize { get; set; }
-        public bool Normalize { get; set; }
-        public int Channels { get; set; }
-    }
 }
diff --git a/OnnxStack.FeatureExtractor/Common/FeatureExtractorModelConfig.cs b/OnnxStack.FeatureExtractor/Common/FeatureExtractorModelConfig.cs
new file mode 100644
index 0000000..50af0f9
--- /dev/null
+++ b/OnnxStack.FeatureExtractor/Common/FeatureExtractorModelConfig.cs
@@ -0,0 +1,15 @@
+using OnnxStack.Core.Config;
+using OnnxStack.Core.Image;
+
+namespace OnnxStack.FeatureExtractor.Common
+{
+    public record FeatureExtractorModelConfig : OnnxModelConfig
+    {
+        public int SampleSize { get; set; }
+        public int OutputChannels { get; set; }
+        public bool NormalizeOutputTensor { get; set; }
+        public bool SetOutputToInputAlpha { get; set; }
+        public ImageResizeMode InputResizeMode { get; set; }
+        public ImageNormalizeType NormalizeInputTensor { get; set; }
+    }
+}
diff --git a/OnnxStack.FeatureExtractor/Pipelines/BackgroundRemovalPipeline.cs b/OnnxStack.FeatureExtractor/Pipelines/BackgroundRemovalPipeline.cs
deleted file mode 100644
index 7cd69b8..0000000
--- a/OnnxStack.FeatureExtractor/Pipelines/BackgroundRemovalPipeline.cs
+++ /dev/null
@@ -1,205 +0,0 @@
-using Microsoft.Extensions.Logging;
-using Microsoft.ML.OnnxRuntime.Tensors;
-using OnnxStack.Core;
-using OnnxStack.Core.Config;
-using OnnxStack.Core.Image;
-using OnnxStack.Core.Model;
-using OnnxStack.Core.Video;
-using OnnxStack.FeatureExtractor.Common;
-using System;
-using System.Collections.Generic;
-using System.IO;
-using System.Linq;
-using System.Runtime.CompilerServices;
-using System.Threading;
-using System.Threading.Tasks;
-
-namespace OnnxStack.FeatureExtractor.Pipelines
-{
-    public class BackgroundRemovalPipeline
-    {
-        private readonly string _name;
-        private readonly ILogger _logger;
-        private readonly FeatureExtractorModel _model;
-
-        /// <summary>
-        /// Initializes a new instance of the class.
-        /// </summary>
-        /// <param name="name">The name.</param>
-        /// <param name="model">The model.</param>
-        /// <param name="logger">The logger.</param>
-        public BackgroundRemovalPipeline(string name, FeatureExtractorModel model, ILogger logger = default)
-        {
-            _name = name;
-            _logger = logger;
-            _model = model;
-        }
-
-
-        /// <summary>
-        /// Gets the name.
-        /// </summary>
-        /// <value></value>
-        public string Name => _name;
-
-
-        /// <summary>
-        /// Loads the model.
-        /// </summary>
-        /// <returns></returns>
-        public Task LoadAsync()
-        {
-            return _model.LoadAsync();
-        }
-
-
-        /// <summary>
-        /// Unloads the models.
-        /// </summary>
-        public async Task UnloadAsync()
-        {
-            await Task.Yield();
-            _model?.Dispose();
-        }
-
-
-        /// <summary>
-        /// Generates the background removal image result
-        /// </summary>
-        /// <param name="inputImage">The input image.</param>
-        /// <returns></returns>
-        public async Task<OnnxImage> RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
-        {
-            var timestamp = _logger?.LogBegin("Removing video background...");
-            var result = await RunInternalAsync(inputImage, cancellationToken);
-            _logger?.LogEnd("Removing video background complete.", timestamp);
-            return result;
-        }
-
-
-        /// <summary>
-        /// Generates the background removal video result
-        /// </summary>
-        /// <param name="video">The input video.</param>
-        /// <returns></returns>
-        public async Task<OnnxVideo> RunAsync(OnnxVideo video, CancellationToken cancellationToken = default)
-        {
-            var timestamp = _logger?.LogBegin("Removing video background...");
-            var videoFrames = new List<OnnxImage>();
-            foreach (var videoFrame in video.Frames)
-            {
-                videoFrames.Add(await RunInternalAsync(videoFrame, cancellationToken));
-            }
-            _logger?.LogEnd("Removing video background complete.", timestamp);
-            return new OnnxVideo(video.Info with
-            {
-                Height = videoFrames[0].Height,
-                Width = videoFrames[0].Width,
-            }, videoFrames);
-        }
-
-
-        /// <summary>
-        /// Generates the background removal video stream
-        /// </summary>
-        /// <param name="imageFrames">The image frames.</param>
-        /// <param name="cancellationToken">The cancellation token.</param>
-        /// <returns></returns>
-        public async IAsyncEnumerable<OnnxImage> RunAsync(IAsyncEnumerable<OnnxImage> imageFrames, [EnumeratorCancellation] CancellationToken cancellationToken = default)
-        {
-            var timestamp = _logger?.LogBegin("Extracting video stream features...");
-            await foreach (var imageFrame in imageFrames)
-            {
-                yield return await RunInternalAsync(imageFrame, cancellationToken);
-            }
-            _logger?.LogEnd("Extracting video stream features complete.", timestamp);
-        }
-
-
-        /// <summary>
-        /// Runs the pipeline
-        /// </summary>
-        /// <param name="inputImage">The input image.</param>
-        /// <param name="cancellationToken">The cancellation token.</param>
-        /// <returns></returns>
-        private async Task<OnnxImage> RunInternalAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
-        {
-            var souceImageTenssor = await inputImage.GetImageTensorAsync(_model.SampleSize, _model.SampleSize, ImageNormalizeType.ZeroToOne);
-            var metadata = await _model.GetMetadataAsync();
-            cancellationToken.ThrowIfCancellationRequested();
-            var outputShape = new[] { 1, _model.Channels, _model.SampleSize, _model.SampleSize };
-            var outputBuffer = metadata.Outputs[0].Value.Dimensions.Length == 4 ? outputShape : outputShape[1..];
-            using (var inferenceParameters = new OnnxInferenceParameters(metadata))
-            {
-                inferenceParameters.AddInputTensor(souceImageTenssor);
-                inferenceParameters.AddOutputBuffer(outputBuffer);
-
-                var results = await _model.RunInferenceAsync(inferenceParameters);
-                using (var result = results.First())
-                {
-                    cancellationToken.ThrowIfCancellationRequested();
-
-                    var imageTensor = AddAlphaChannel(souceImageTenssor, result.GetTensorDataAsSpan<float>());
-                    return new OnnxImage(imageTensor, ImageNormalizeType.ZeroToOne);
-                }
-            }
-        }
-
-
-        /// <summary>
-        /// Adds an alpha channel to the RGB tensor.
-        /// </summary>
-        /// <param name="sourceImage">The source image.</param>
-        /// <param name="alphaChannel">The alpha channel.</param>
-        /// <returns></returns>
-        private static DenseTensor<float> AddAlphaChannel(DenseTensor<float> sourceImage, ReadOnlySpan<float> alphaChannel)
-        {
-            var resultTensor = new DenseTensor<float>(new int[] { 1, 4, sourceImage.Dimensions[2], sourceImage.Dimensions[3] });
-            sourceImage.Buffer.Span.CopyTo(resultTensor.Buffer[..(int)sourceImage.Length].Span);
-            alphaChannel.CopyTo(resultTensor.Buffer[(int)sourceImage.Length..].Span);
-            return resultTensor;
-        }
-
-
-        /// <summary>
-        /// Creates the pipeline from a FeatureExtractorModelSet.
-        /// </summary>
-        /// <param name="modelSet">The model set.</param>
-        /// <param name="logger">The logger.</param>
-        /// <returns></returns>
-        public static BackgroundRemovalPipeline CreatePipeline(FeatureExtractorModelSet modelSet, ILogger logger = default)
-        {
-            var model = new FeatureExtractorModel(modelSet.FeatureExtractorConfig.ApplyDefaults(modelSet));
-            return new BackgroundRemovalPipeline(modelSet.Name, model, logger);
-        }
-
-
-        /// <summary>
-        /// Creates the pipeline from the specified file.
-        /// </summary>
-        /// <param name="modelFile">The model file.</param>
-        /// <param name="deviceId">The device identifier.</param>
-        /// <param name="executionProvider">The execution provider.</param>
-        /// <param name="logger">The logger.</param>
-        /// <returns></returns>
-        public static BackgroundRemovalPipeline CreatePipeline(string modelFile, int sampleSize = 512, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
-        {
-            var name = Path.GetFileNameWithoutExtension(modelFile);
-            var configuration = new FeatureExtractorModelSet
-            {
-                Name = name,
-                IsEnabled = true,
-                DeviceId = deviceId,
-                ExecutionProvider = executionProvider,
-                FeatureExtractorConfig = new FeatureExtractorModelConfig
-                {
-                    OnnxModelPath = modelFile,
-                    SampleSize = sampleSize,
-                    Normalize = false,
-                    Channels = 1
-                }
-            };
-            return CreatePipeline(configuration, logger);
-        }
-    }
-}
diff --git a/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs b/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs
index dc5d913..5ca8ee7 100644
--- a/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs
+++ b/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs
@@ -1,4 +1,5 @@
 using Microsoft.Extensions.Logging;
+using Microsoft.ML.OnnxRuntime.Tensors;
 using OnnxStack.Core;
 using OnnxStack.Core.Config;
 using OnnxStack.Core.Image;
@@ -118,31 +119,61 @@ public async IAsyncEnumerable<OnnxImage> RunAsync(IAsyncEnumerable<OnnxImage> im
         /// <returns></returns>
         private async Task<OnnxImage> RunInternalAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
         {
-            var controlImage = await inputImage.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, ImageNormalizeType.ZeroToOne);
+            var originalWidth = inputImage.Width;
+            var originalHeight = inputImage.Height;
+            var inputTensor = _featureExtractorModel.SampleSize <= 0
+                ? await inputImage.GetImageTensorAsync(_featureExtractorModel.InputNormalization)
+                : await inputImage.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, _featureExtractorModel.InputNormalization, resizeMode: _featureExtractorModel.InputResizeMode);
             var metadata = await _featureExtractorModel.GetMetadataAsync();
             cancellationToken.ThrowIfCancellationRequested();
-            var outputShape = new[] { 1, _featureExtractorModel.Channels, _featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize };
+            var outputShape = new[] { 1, _featureExtractorModel.OutputChannels, inputTensor.Dimensions[2], inputTensor.Dimensions[3] };
             var outputBuffer = metadata.Outputs[0].Value.Dimensions.Length == 4 ? outputShape : outputShape[1..];
             using (var inferenceParameters = new OnnxInferenceParameters(metadata))
             {
-                inferenceParameters.AddInputTensor(controlImage);
+                inferenceParameters.AddInputTensor(inputTensor);
                 inferenceParameters.AddOutputBuffer(outputBuffer);

-                var results = await _featureExtractorModel.RunInferenceAsync(inferenceParameters);
-                using (var result = results.First())
+                var inferenceResults = await _featureExtractorModel.RunInferenceAsync(inferenceParameters);
+                using (var inferenceResult = inferenceResults.First())
                 {
                     cancellationToken.ThrowIfCancellationRequested();

-                    var resultTensor = result.ToDenseTensor(outputShape);
-                    if (_featureExtractorModel.Normalize)
-                        resultTensor.NormalizeMinMax();
+                    var outputTensor = inferenceResult.ToDenseTensor(outputShape);
+                    if (_featureExtractorModel.NormalizeOutputTensor)
+                        outputTensor.NormalizeMinMax();

-                    return resultTensor.ToImageMask();
+                    var imageResult = default(OnnxImage);
+                    if (_featureExtractorModel.SetOutputToInputAlpha)
+                        imageResult = new OnnxImage(AddAlphaChannel(inputTensor, outputTensor), _featureExtractorModel.InputNormalization);
+                    else if (_featureExtractorModel.OutputChannels >= 3)
+                        imageResult = new OnnxImage(outputTensor, _featureExtractorModel.InputNormalization);
+                    else
+                        imageResult = outputTensor.ToImageMask();
+
+                    if (_featureExtractorModel.InputResizeMode == ImageResizeMode.Stretch && (imageResult.Width != originalWidth || imageResult.Height != originalHeight))
+                        imageResult.Resize(originalHeight, originalWidth, _featureExtractorModel.InputResizeMode);
+
+                    return imageResult;
                 }
             }
         }


+        /// <summary>
+        /// Adds an alpha channel to the RGB tensor.
+        /// </summary>
+        /// <param name="sourceImage">The source image.</param>
+        /// <param name="alphaChannel">The alpha channel.</param>
+        /// <returns></returns>
+        private static DenseTensor<float> AddAlphaChannel(DenseTensor<float> sourceImage, DenseTensor<float> alphaChannel)
+        {
+            var resultTensor = new DenseTensor<float>(new int[] { 1, 4, sourceImage.Dimensions[2], sourceImage.Dimensions[3] });
+            sourceImage.Buffer.Span.CopyTo(resultTensor.Buffer[..(int)sourceImage.Length].Span);
+            alphaChannel.Buffer.Span.CopyTo(resultTensor.Buffer[(int)sourceImage.Length..].Span);
+            return resultTensor;
+        }
+
+
         /// <summary>
         /// Creates the pipeline from a FeatureExtractorModelSet.
         /// </summary>
@@ -157,14 +188,19 @@ public static FeatureExtractorPipeline CreatePipeline(FeatureExtractorModelSet m


         /// <summary>
-        /// Creates the pipeline from the specified file.
+        /// Creates the pipeline from the specified arguments.
         /// </summary>
         /// <param name="modelFile">The model file.</param>
+        /// <param name="sampleSize">Size of the sample.</param>
+        /// <param name="outputChannels">The channels.</param>
+        /// <param name="normalizeOutputTensor">if set to <c>true</c> [normalize output tensor].</param>
+        /// <param name="normalizeInputTensor">The normalize input tensor.</param>
+        /// <param name="setOutputToInputAlpha">if set to <c>true</c> [set output to input alpha].</param>
         /// <param name="deviceId">The device identifier.</param>
         /// <param name="executionProvider">The execution provider.</param>
         /// <param name="logger">The logger.</param>
         /// <returns></returns>
-        public static FeatureExtractorPipeline CreatePipeline(string modelFile, bool normalize = false, int sampleSize = 512, int channels = 1, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
+        public static FeatureExtractorPipeline CreatePipeline(string modelFile, int sampleSize = 0, int outputChannels = 1, bool normalizeOutputTensor = false, ImageNormalizeType normalizeInputTensor = ImageNormalizeType.ZeroToOne, ImageResizeMode inputResizeMode = ImageResizeMode.Crop, bool setOutputToInputAlpha = false, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
         {
             var name = Path.GetFileNameWithoutExtension(modelFile);
             var configuration = new FeatureExtractorModelSet
@@ -177,8 +213,11 @@ public static FeatureExtractorPipeline CreatePipeline(string modelFile, bool nor
                 {
                     OnnxModelPath = modelFile,
                     SampleSize = sampleSize,
-                    Normalize = normalize,
-                    Channels = channels
+                    OutputChannels = outputChannels,
+                    NormalizeOutputTensor = normalizeOutputTensor,
+                    SetOutputToInputAlpha = setOutputToInputAlpha,
+                    NormalizeInputTensor = normalizeInputTensor,
+                    InputResizeMode = inputResizeMode
                 }
             };
             return CreatePipeline(configuration, logger);
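Note (not part of the diff above): with BackgroundRemovalPipeline and its two console examples removed, background removal is expected to run through the extended FeatureExtractorPipeline. The sketch below is illustrative only; it is assembled from the CreatePipeline/RunAsync signatures visible in this diff, the file paths are placeholders, and the option values mirror the RMBG-1.4 entry added to FeatureExtractorExample.cs.

// Hypothetical usage sketch: background removal via the consolidated FeatureExtractorPipeline.
// Paths are placeholders; option values follow the RMBG-1.4 pipeline in FeatureExtractorExample.cs.
using OnnxStack.Core.Image;
using OnnxStack.FeatureExtractor.Pipelines;

var inputImage = await OnnxImage.FromFileAsync("input.png");       // placeholder input path

var pipeline = FeatureExtractorPipeline.CreatePipeline(
    "model.onnx",                         // placeholder path to an RMBG-1.4 style matting model
    sampleSize: 1024,                     // resize the input to 1024x1024 before inference
    setOutputToInputAlpha: true,          // write the single-channel output into the input's alpha channel
    inputResizeMode: ImageResizeMode.Stretch);

var result = await pipeline.RunAsync(inputImage);                  // OnnxImage overload, as used in FeatureExtractorExample
await result.SaveAsync("background-removed.png");
await pipeline.UnloadAsync();                                      // assumes the same Load/Unload surface as the removed pipeline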