diff --git a/OnnxStack.Console/Examples/BackgroundRemovalImageExample.cs b/OnnxStack.Console/Examples/BackgroundRemovalImageExample.cs
deleted file mode 100644
index 40551b8..0000000
--- a/OnnxStack.Console/Examples/BackgroundRemovalImageExample.cs
+++ /dev/null
@@ -1,51 +0,0 @@
-using OnnxStack.Core.Image;
-using OnnxStack.FeatureExtractor.Pipelines;
-using System.Diagnostics;
-
-namespace OnnxStack.Console.Runner
-{
- public sealed class BackgroundRemovalImageExample : IExampleRunner
- {
- private readonly string _outputDirectory;
-
- public BackgroundRemovalImageExample()
- {
- _outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", "BackgroundRemovalExample");
- Directory.CreateDirectory(_outputDirectory);
- }
-
- public int Index => 20;
-
- public string Name => "Image Background Removal Example";
-
- public string Description => "Remove a background from an image";
-
- /// <summary>
- /// ControlNet Example
- /// </summary>
- public async Task RunAsync()
- {
- OutputHelpers.WriteConsole("Please enter an image file path and press ENTER", ConsoleColor.Yellow);
- var imageFile = OutputHelpers.ReadConsole(ConsoleColor.Cyan);
-
- var timestamp = Stopwatch.GetTimestamp();
-
- OutputHelpers.WriteConsole($"Load Image", ConsoleColor.Gray);
- var inputImage = await OnnxImage.FromFileAsync(imageFile);
-
- OutputHelpers.WriteConsole($"Create Pipeline", ConsoleColor.Gray);
- var pipeline = BackgroundRemovalPipeline.CreatePipeline("D:\\Repositories\\RMBG-1.4\\onnx\\model.onnx", sampleSize: 1024);
-
- OutputHelpers.WriteConsole($"Run Pipeline", ConsoleColor.Gray);
- var imageFeature = await pipeline.RunAsync(inputImage);
-
- OutputHelpers.WriteConsole($"Save Image", ConsoleColor.Gray);
- await imageFeature.SaveAsync(Path.Combine(_outputDirectory, $"{pipeline.Name}.png"));
-
- OutputHelpers.WriteConsole($"Unload pipeline", ConsoleColor.Gray);
- await pipeline.UnloadAsync();
-
- OutputHelpers.WriteConsole($"Elapsed: {Stopwatch.GetElapsedTime(timestamp)}ms", ConsoleColor.Yellow);
- }
- }
-}
diff --git a/OnnxStack.Console/Examples/BackgroundRemovalVideoExample.cs b/OnnxStack.Console/Examples/BackgroundRemovalVideoExample.cs
deleted file mode 100644
index 788b8d1..0000000
--- a/OnnxStack.Console/Examples/BackgroundRemovalVideoExample.cs
+++ /dev/null
@@ -1,54 +0,0 @@
-using OnnxStack.Core.Video;
-using OnnxStack.FeatureExtractor.Pipelines;
-using System.Diagnostics;
-
-namespace OnnxStack.Console.Runner
-{
- public sealed class BackgroundRemovalVideoExample : IExampleRunner
- {
- private readonly string _outputDirectory;
-
- public BackgroundRemovalVideoExample()
- {
- _outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", "BackgroundRemovalExample");
- Directory.CreateDirectory(_outputDirectory);
- }
-
- public int Index => 21;
-
- public string Name => "Video Background Removal Example";
-
- public string Description => "Remove a background from a video";
-
- public async Task RunAsync()
- {
- OutputHelpers.WriteConsole("Please enter an video/gif file path and press ENTER", ConsoleColor.Yellow);
- var videoFile = OutputHelpers.ReadConsole(ConsoleColor.Cyan);
-
- var timestamp = Stopwatch.GetTimestamp();
-
- OutputHelpers.WriteConsole($"Read Video", ConsoleColor.Gray);
- var videoInfo = await VideoHelper.ReadVideoInfoAsync(videoFile);
-
- OutputHelpers.WriteConsole($"Create Pipeline", ConsoleColor.Gray);
- var pipeline = BackgroundRemovalPipeline.CreatePipeline("D:\\Repositories\\RMBG-1.4\\onnx\\model.onnx", sampleSize: 1024);
-
- OutputHelpers.WriteConsole($"Load Pipeline", ConsoleColor.Gray);
- await pipeline.LoadAsync();
-
- OutputHelpers.WriteConsole($"Create Video Stream", ConsoleColor.Gray);
- var videoStream = VideoHelper.ReadVideoStreamAsync(videoFile, videoInfo.FrameRate);
-
- OutputHelpers.WriteConsole($"Create Pipeline Stream", ConsoleColor.Gray);
- var pipelineStream = pipeline.RunAsync(videoStream);
-
- OutputHelpers.WriteConsole($"Write Video Stream", ConsoleColor.Gray);
- await VideoHelper.WriteVideoStreamAsync(videoInfo, pipelineStream, Path.Combine(_outputDirectory, $"Result.mp4"), true);
-
- OutputHelpers.WriteConsole($"Unload", ConsoleColor.Gray);
- await pipeline.UnloadAsync();
-
- OutputHelpers.WriteConsole($"Elapsed: {Stopwatch.GetElapsedTime(timestamp)}ms", ConsoleColor.Yellow);
- }
- }
-}
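Note: the two deleted examples above are superseded by FeatureExtractorPipeline, which gains a setOutputToInputAlpha option later in this diff. A minimal sketch of the equivalent image flow, assuming the RMBG-1.4 model path from the removed example and the new CreatePipeline overload (file names and output location are illustrative):

    // Illustrative only: background removal via FeatureExtractorPipeline.
    var pipeline = FeatureExtractorPipeline.CreatePipeline(
        "D:\\Repositories\\RMBG-1.4\\onnx\\model.onnx",
        sampleSize: 1024,
        setOutputToInputAlpha: true,
        inputResizeMode: ImageResizeMode.Stretch);
    var inputImage = await OnnxImage.FromFileAsync("input.png");
    var result = await pipeline.RunAsync(inputImage);      // RGB input with the mask written to the alpha channel
    await result.SaveAsync("background-removed.png");
    await pipeline.UnloadAsync();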
diff --git a/OnnxStack.Console/Examples/ControlNetFeatureExample.cs b/OnnxStack.Console/Examples/ControlNetFeatureExample.cs
index a3bd0e4..2341f04 100644
--- a/OnnxStack.Console/Examples/ControlNetFeatureExample.cs
+++ b/OnnxStack.Console/Examples/ControlNetFeatureExample.cs
@@ -35,7 +35,7 @@ public async Task RunAsync()
var inputImage = await OnnxImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\Img2Img_Start.bmp");
// Create Annotation pipeline
- var annotationPipeline = FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", true);
+ var annotationPipeline = FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", sampleSize: 512, normalizeOutputTensor: true);
// Create Depth Image
var controlImage = await annotationPipeline.RunAsync(inputImage);
diff --git a/OnnxStack.Console/Examples/FeatureExtractorExample.cs b/OnnxStack.Console/Examples/FeatureExtractorExample.cs
index c45e7b6..e7a1ee3 100644
--- a/OnnxStack.Console/Examples/FeatureExtractorExample.cs
+++ b/OnnxStack.Console/Examples/FeatureExtractorExample.cs
@@ -37,10 +37,8 @@ public async Task RunAsync()
{
FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\canny.onnx"),
FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\hed.onnx"),
- FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", true),
-
- // FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\depth-anything-large-hf\\onnx\\model.onnx", normalize: true, sampleSize: 504),
- // FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\sentis-MiDaS\\dpt_beit_large_512.onnx", normalize: true, sampleSize: 384),
+ FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", sampleSize: 512, normalizeOutputTensor: true, inputResizeMode: ImageResizeMode.Stretch),
+ FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\RMBG-1.4\\onnx\\model.onnx", sampleSize: 1024, setOutputToInputAlpha: true, inputResizeMode: ImageResizeMode.Stretch)
};
foreach (var pipeline in pipelines)
@@ -49,7 +47,7 @@ public async Task RunAsync()
OutputHelpers.WriteConsole($"Load pipeline`{pipeline.Name}`", ConsoleColor.Cyan);
// Run Image Pipeline
- var imageFeature = await pipeline.RunAsync(inputImage);
+ var imageFeature = await pipeline.RunAsync(inputImage.Clone());
OutputHelpers.WriteConsole($"Generating image", ConsoleColor.Cyan);
diff --git a/OnnxStack.Core/Extensions/Extensions.cs b/OnnxStack.Core/Extensions/Extensions.cs
index dafd961..3ea62c6 100644
--- a/OnnxStack.Core/Extensions/Extensions.cs
+++ b/OnnxStack.Core/Extensions/Extensions.cs
@@ -1,4 +1,5 @@
using Microsoft.ML.OnnxRuntime;
+using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core.Config;
using System;
using System.Collections.Concurrent;
@@ -244,5 +245,26 @@ public static long[] ToLong(this int[] array)
{
return Array.ConvertAll(array, Convert.ToInt64);
}
+
+
+ /// <summary>
+ /// Normalize the data using Min-Max scaling to ensure all values are in the range [0, 1].
+ /// </summary>
+ /// <param name="values">The values.</param>
+ public static void NormalizeMinMax(this Span<float> values)
+ {
+ float min = float.PositiveInfinity, max = float.NegativeInfinity;
+ foreach (var val in values)
+ {
+ if (min > val) min = val;
+ if (max < val) max = val;
+ }
+
+ var range = max - min;
+ for (var i = 0; i < values.Length; i++)
+ {
+ values[i] = (values[i] - min) / range;
+ }
+ }
}
}
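For reference, a usage sketch of the new Span<float> extension (the input values are made up); min-max scaling maps the smallest value to 0 and the largest to 1, in place:

    var values = new float[] { -2f, 0f, 6f };
    values.AsSpan().NormalizeMinMax();   // values is now { 0f, 0.25f, 1f }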
diff --git a/OnnxStack.Core/Extensions/TensorExtension.cs b/OnnxStack.Core/Extensions/TensorExtension.cs
index 6e93467..ffcb29b 100644
--- a/OnnxStack.Core/Extensions/TensorExtension.cs
+++ b/OnnxStack.Core/Extensions/TensorExtension.cs
@@ -286,19 +286,7 @@ public static DenseTensor<float> Repeat(this DenseTensor<float> tensor1, int cou
/// <param name="tensor">The tensor.</param>
public static void NormalizeMinMax(this DenseTensor<float> tensor)
{
- var values = tensor.Buffer.Span;
- float min = float.PositiveInfinity, max = float.NegativeInfinity;
- foreach (var val in values)
- {
- if (min > val) min = val;
- if (max < val) max = val;
- }
-
- var range = max - min;
- for (var i = 0; i < values.Length; i++)
- {
- values[i] = (values[i] - min) / range;
- }
+ tensor.Buffer.Span.NormalizeMinMax();
}
diff --git a/OnnxStack.Core/Image/Extensions.cs b/OnnxStack.Core/Image/Extensions.cs
index 8a74b4e..e75d86a 100644
--- a/OnnxStack.Core/Image/Extensions.cs
+++ b/OnnxStack.Core/Image/Extensions.cs
@@ -1,6 +1,7 @@
using Microsoft.ML.OnnxRuntime.Tensors;
using SixLabors.ImageSharp;
using SixLabors.ImageSharp.PixelFormats;
+using SixLabors.ImageSharp.Processing;
namespace OnnxStack.Core.Image
{
@@ -29,11 +30,27 @@ public static OnnxImage ToImageMask(this DenseTensor<float> imageTensor)
}
}
+
+ public static ResizeMode ToResizeMode(this ImageResizeMode resizeMode)
+ {
+ return resizeMode switch
+ {
+ ImageResizeMode.Stretch => ResizeMode.Stretch,
+ _ => ResizeMode.Crop
+ };
+ }
+
}
public enum ImageNormalizeType
{
ZeroToOne = 0,
- OneToOne = 1,
+ OneToOne = 1
+ }
+
+ public enum ImageResizeMode
+ {
+ Crop = 0,
+ Stretch = 1
}
}
diff --git a/OnnxStack.Core/Image/OnnxImage.cs b/OnnxStack.Core/Image/OnnxImage.cs
index 643e514..a9d9cb2 100644
--- a/OnnxStack.Core/Image/OnnxImage.cs
+++ b/OnnxStack.Core/Image/OnnxImage.cs
@@ -230,15 +230,27 @@ public DenseTensor<float> GetImageTensor(ImageNormalizeType normalizeType = Imag
/// <param name="normalizeType">Type of the normalize.</param>
/// <param name="channels">The channels.</param>
/// <returns></returns>
- public DenseTensor<float> GetImageTensor(int height, int width, ImageNormalizeType normalizeType = ImageNormalizeType.OneToOne, int channels = 3)
+ public DenseTensor<float> GetImageTensor(int height, int width, ImageNormalizeType normalizeType = ImageNormalizeType.OneToOne, int channels = 3, ImageResizeMode resizeMode = ImageResizeMode.Crop)
{
if (height > 0 && width > 0)
- Resize(height, width);
+ Resize(height, width, resizeMode);
return GetImageTensor(normalizeType, channels);
}
+ /// <summary>
+ /// Gets the image as tensor asynchronously.
+ /// </summary>
+ /// <param name="normalizeType">Type of the normalize.</param>
+ /// <param name="channels">The channels.</param>
+ /// <returns></returns>
+ public Task<DenseTensor<float>> GetImageTensorAsync(ImageNormalizeType normalizeType = ImageNormalizeType.OneToOne, int channels = 3)
+ {
+ return Task.Run(() => GetImageTensor(normalizeType, channels));
+ }
+
+
/// <summary>
/// Gets the image as tensor asynchronously.
/// </summary>
@@ -247,9 +259,9 @@ public DenseTensor<float> GetImageTensor(int height, int width, ImageNormalizeTy
/// <param name="normalizeType">Type of the normalize.</param>
/// <param name="channels">The channels.</param>
/// <returns></returns>
- public Task<DenseTensor<float>> GetImageTensorAsync(int height, int width, ImageNormalizeType normalizeType = ImageNormalizeType.OneToOne, int channels = 3)
+ public Task<DenseTensor<float>> GetImageTensorAsync(int height, int width, ImageNormalizeType normalizeType = ImageNormalizeType.OneToOne, int channels = 3, ImageResizeMode resizeMode = ImageResizeMode.Crop)
{
- return Task.Run(() => GetImageTensor(height, width, normalizeType, channels));
+ return Task.Run(() => GetImageTensor(height, width, normalizeType, channels, resizeMode));
}
@@ -259,20 +271,21 @@ public Task<DenseTensor<float>> GetImageTensorAsync(int height, int width, Image
/// <param name="height">The height.</param>
/// <param name="width">The width.</param>
/// <param name="resizeMode">The resize mode.</param>
- public void Resize(int height, int width, ResizeMode resizeMode = ResizeMode.Crop)
+ public void Resize(int height, int width, ImageResizeMode resizeMode = ImageResizeMode.Crop)
{
_imageData.Mutate(x =>
{
x.Resize(new ResizeOptions
{
Size = new Size(width, height),
- Mode = resizeMode,
+ Mode = resizeMode.ToResizeMode(),
Sampler = KnownResamplers.Lanczos8,
Compand = true
});
});
}
+
public OnnxImage Clone()
{
return new OnnxImage(_imageData);
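A short sketch of the new resize-mode-aware tensor call (file name and dimensions are illustrative):

    var image = await OnnxImage.FromFileAsync("photo.png");
    // Stretch to 512x512 instead of the default center crop before tensorization.
    var tensor = await image.GetImageTensorAsync(512, 512,
        ImageNormalizeType.ZeroToOne, channels: 3, resizeMode: ImageResizeMode.Stretch);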
diff --git a/OnnxStack.FeatureExtractor/Common/FeatureExtractorModel.cs b/OnnxStack.FeatureExtractor/Common/FeatureExtractorModel.cs
index 2760668..e63b159 100644
--- a/OnnxStack.FeatureExtractor/Common/FeatureExtractorModel.cs
+++ b/OnnxStack.FeatureExtractor/Common/FeatureExtractorModel.cs
@@ -1,56 +1,52 @@
using Microsoft.ML.OnnxRuntime;
using OnnxStack.Core.Config;
+using OnnxStack.Core.Image;
using OnnxStack.Core.Model;
namespace OnnxStack.FeatureExtractor.Common
{
public class FeatureExtractorModel : OnnxModelSession
{
- private readonly int _sampleSize;
- private readonly bool _normalize;
- private readonly int _channels;
+ private readonly FeatureExtractorModelConfig _configuration;
public FeatureExtractorModel(FeatureExtractorModelConfig configuration)
: base(configuration)
{
- _sampleSize = configuration.SampleSize;
- _normalize = configuration.Normalize;
- _channels = configuration.Channels;
+ _configuration = configuration;
}
- public int SampleSize => _sampleSize;
-
- public bool Normalize => _normalize;
-
- public int Channels => _channels;
+ public int OutputChannels => _configuration.OutputChannels;
+ public int SampleSize => _configuration.SampleSize;
+ public bool NormalizeOutputTensor => _configuration.NormalizeOutputTensor;
+ public bool SetOutputToInputAlpha => _configuration.SetOutputToInputAlpha;
+ public ImageResizeMode InputResizeMode => _configuration.InputResizeMode;
+ public ImageNormalizeType InputNormalization => _configuration.NormalizeInputTensor;
public static FeatureExtractorModel Create(FeatureExtractorModelConfig configuration)
{
return new FeatureExtractorModel(configuration);
}
- public static FeatureExtractorModel Create(string modelFile, bool normalize = false, int sampleSize = 512, int channels = 3, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
+ public static FeatureExtractorModel Create(string modelFile, int sampleSize = 0, int outputChannels = 1, bool normalizeOutputTensor = false, ImageNormalizeType normalizeInputTensor = ImageNormalizeType.ZeroToOne, ImageResizeMode inputResizeMode = ImageResizeMode.Crop, bool setOutputToInputAlpha = false, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
{
var configuration = new FeatureExtractorModelConfig
{
- SampleSize = sampleSize,
- Normalize = normalize,
- Channels = channels,
DeviceId = deviceId,
ExecutionProvider = executionProvider,
ExecutionMode = ExecutionMode.ORT_SEQUENTIAL,
InterOpNumThreads = 0,
IntraOpNumThreads = 0,
- OnnxModelPath = modelFile
+ OnnxModelPath = modelFile,
+
+
+ SampleSize = sampleSize,
+ OutputChannels = outputChannels,
+ NormalizeOutputTensor = normalizeOutputTensor,
+ SetOutputToInputAlpha = setOutputToInputAlpha,
+ NormalizeInputTensor = normalizeInputTensor,
+ InputResizeMode = inputResizeMode
};
return new FeatureExtractorModel(configuration);
}
}
-
- public record FeatureExtractorModelConfig : OnnxModelConfig
- {
- public int SampleSize { get; set; }
- public bool Normalize { get; set; }
- public int Channels { get; set; }
- }
}
diff --git a/OnnxStack.FeatureExtractor/Common/FeatureExtractorModelConfig.cs b/OnnxStack.FeatureExtractor/Common/FeatureExtractorModelConfig.cs
new file mode 100644
index 0000000..50af0f9
--- /dev/null
+++ b/OnnxStack.FeatureExtractor/Common/FeatureExtractorModelConfig.cs
@@ -0,0 +1,15 @@
+using OnnxStack.Core.Config;
+using OnnxStack.Core.Image;
+
+namespace OnnxStack.FeatureExtractor.Common
+{
+ public record FeatureExtractorModelConfig : OnnxModelConfig
+ {
+ public int SampleSize { get; set; }
+ public int OutputChannels { get; set; }
+ public bool NormalizeOutputTensor { get; set; }
+ public bool SetOutputToInputAlpha { get; set; }
+ public ImageResizeMode InputResizeMode { get; set; }
+ public ImageNormalizeType NormalizeInputTensor { get; set; }
+ }
+}
diff --git a/OnnxStack.FeatureExtractor/Pipelines/BackgroundRemovalPipeline.cs b/OnnxStack.FeatureExtractor/Pipelines/BackgroundRemovalPipeline.cs
deleted file mode 100644
index 7cd69b8..0000000
--- a/OnnxStack.FeatureExtractor/Pipelines/BackgroundRemovalPipeline.cs
+++ /dev/null
@@ -1,205 +0,0 @@
-using Microsoft.Extensions.Logging;
-using Microsoft.ML.OnnxRuntime.Tensors;
-using OnnxStack.Core;
-using OnnxStack.Core.Config;
-using OnnxStack.Core.Image;
-using OnnxStack.Core.Model;
-using OnnxStack.Core.Video;
-using OnnxStack.FeatureExtractor.Common;
-using System;
-using System.Collections.Generic;
-using System.IO;
-using System.Linq;
-using System.Runtime.CompilerServices;
-using System.Threading;
-using System.Threading.Tasks;
-
-namespace OnnxStack.FeatureExtractor.Pipelines
-{
- public class BackgroundRemovalPipeline
- {
- private readonly string _name;
- private readonly ILogger _logger;
- private readonly FeatureExtractorModel _model;
-
- /// <summary>
- /// Initializes a new instance of the <see cref="BackgroundRemovalPipeline"/> class.
- /// </summary>
- /// <param name="name">The name.</param>
- /// <param name="model">The model.</param>
- /// <param name="logger">The logger.</param>
- public BackgroundRemovalPipeline(string name, FeatureExtractorModel model, ILogger logger = default)
- {
- _name = name;
- _logger = logger;
- _model = model;
- }
-
-
- /// <summary>
- /// Gets the name.
- /// </summary>
- ///
- public string Name => _name;
-
-
- /// <summary>
- /// Loads the model.
- /// </summary>
- /// <returns></returns>
- public Task LoadAsync()
- {
- return _model.LoadAsync();
- }
-
-
- /// <summary>
- /// Unloads the models.
- /// </summary>
- public async Task UnloadAsync()
- {
- await Task.Yield();
- _model?.Dispose();
- }
-
-
- /// <summary>
- /// Generates the background removal image result
- /// </summary>
- /// <param name="inputImage">The input image.</param>
- /// <returns></returns>
- public async Task<OnnxImage> RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
- {
- var timestamp = _logger?.LogBegin("Removing video background...");
- var result = await RunInternalAsync(inputImage, cancellationToken);
- _logger?.LogEnd("Removing video background complete.", timestamp);
- return result;
- }
-
-
- /// <summary>
- /// Generates the background removal video result
- /// </summary>
- /// <param name="video">The input video.</param>
- /// <returns></returns>
- public async Task<OnnxVideo> RunAsync(OnnxVideo video, CancellationToken cancellationToken = default)
- {
- var timestamp = _logger?.LogBegin("Removing video background...");
- var videoFrames = new List<OnnxImage>();
- foreach (var videoFrame in video.Frames)
- {
- videoFrames.Add(await RunInternalAsync(videoFrame, cancellationToken));
- }
- _logger?.LogEnd("Removing video background complete.", timestamp);
- return new OnnxVideo(video.Info with
- {
- Height = videoFrames[0].Height,
- Width = videoFrames[0].Width,
- }, videoFrames);
- }
-
-
- /// <summary>
- /// Generates the background removal video stream
- /// </summary>
- /// <param name="imageFrames">The image frames.</param>
- /// <param name="cancellationToken">The cancellation token.</param>
- /// <returns></returns>
- public async IAsyncEnumerable<OnnxImage> RunAsync(IAsyncEnumerable<OnnxImage> imageFrames, [EnumeratorCancellation] CancellationToken cancellationToken = default)
- {
- var timestamp = _logger?.LogBegin("Extracting video stream features...");
- await foreach (var imageFrame in imageFrames)
- {
- yield return await RunInternalAsync(imageFrame, cancellationToken);
- }
- _logger?.LogEnd("Extracting video stream features complete.", timestamp);
- }
-
-
- /// <summary>
- /// Runs the pipeline
- /// </summary>
- /// <param name="inputImage">The input image.</param>
- /// <param name="cancellationToken">The cancellation token.</param>
- /// <returns></returns>
- private async Task<OnnxImage> RunInternalAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
- {
- var souceImageTenssor = await inputImage.GetImageTensorAsync(_model.SampleSize, _model.SampleSize, ImageNormalizeType.ZeroToOne);
- var metadata = await _model.GetMetadataAsync();
- cancellationToken.ThrowIfCancellationRequested();
- var outputShape = new[] { 1, _model.Channels, _model.SampleSize, _model.SampleSize };
- var outputBuffer = metadata.Outputs[0].Value.Dimensions.Length == 4 ? outputShape : outputShape[1..];
- using (var inferenceParameters = new OnnxInferenceParameters(metadata))
- {
- inferenceParameters.AddInputTensor(souceImageTenssor);
- inferenceParameters.AddOutputBuffer(outputBuffer);
-
- var results = await _model.RunInferenceAsync(inferenceParameters);
- using (var result = results.First())
- {
- cancellationToken.ThrowIfCancellationRequested();
-
- var imageTensor = AddAlphaChannel(souceImageTenssor, result.GetTensorDataAsSpan<float>());
- return new OnnxImage(imageTensor, ImageNormalizeType.ZeroToOne);
- }
- }
- }
-
-
- /// <summary>
- /// Adds an alpha channel to the RGB tensor.
- /// </summary>
- /// <param name="sourceImage">The source image.</param>
- /// <param name="alphaChannel">The alpha channel.</param>
- /// <returns></returns>
- private static DenseTensor<float> AddAlphaChannel(DenseTensor<float> sourceImage, ReadOnlySpan<float> alphaChannel)
- {
- var resultTensor = new DenseTensor<float>(new int[] { 1, 4, sourceImage.Dimensions[2], sourceImage.Dimensions[3] });
- sourceImage.Buffer.Span.CopyTo(resultTensor.Buffer[..(int)sourceImage.Length].Span);
- alphaChannel.CopyTo(resultTensor.Buffer[(int)sourceImage.Length..].Span);
- return resultTensor;
- }
-
-
- /// <summary>
- /// Creates the pipeline from a FeatureExtractorModelSet.
- /// </summary>
- /// <param name="modelSet">The model set.</param>
- /// <param name="logger">The logger.</param>
- /// <returns></returns>
- public static BackgroundRemovalPipeline CreatePipeline(FeatureExtractorModelSet modelSet, ILogger logger = default)
- {
- var model = new FeatureExtractorModel(modelSet.FeatureExtractorConfig.ApplyDefaults(modelSet));
- return new BackgroundRemovalPipeline(modelSet.Name, model, logger);
- }
-
-
- /// <summary>
- /// Creates the pipeline from the specified file.
- /// </summary>
- /// <param name="modelFile">The model file.</param>
- /// <param name="deviceId">The device identifier.</param>
- /// <param name="executionProvider">The execution provider.</param>
- /// <param name="logger">The logger.</param>
- /// <returns></returns>
- public static BackgroundRemovalPipeline CreatePipeline(string modelFile, int sampleSize = 512, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
- {
- var name = Path.GetFileNameWithoutExtension(modelFile);
- var configuration = new FeatureExtractorModelSet
- {
- Name = name,
- IsEnabled = true,
- DeviceId = deviceId,
- ExecutionProvider = executionProvider,
- FeatureExtractorConfig = new FeatureExtractorModelConfig
- {
- OnnxModelPath = modelFile,
- SampleSize = sampleSize,
- Normalize = false,
- Channels = 1
- }
- };
- return CreatePipeline(configuration, logger);
- }
- }
-}
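The streaming overload on FeatureExtractorPipeline covers the deleted video path; a hedged sketch reusing the VideoHelper calls from the removed BackgroundRemovalVideoExample (model path and file names are illustrative):

    var videoFile = "input.mp4";
    var videoInfo = await VideoHelper.ReadVideoInfoAsync(videoFile);
    var pipeline = FeatureExtractorPipeline.CreatePipeline(
        "D:\\Repositories\\RMBG-1.4\\onnx\\model.onnx",
        sampleSize: 1024, setOutputToInputAlpha: true, inputResizeMode: ImageResizeMode.Stretch);
    var frames = VideoHelper.ReadVideoStreamAsync(videoFile, videoInfo.FrameRate);
    var processed = pipeline.RunAsync(frames);                  // IAsyncEnumerable<OnnxImage>
    await VideoHelper.WriteVideoStreamAsync(videoInfo, processed, "Result.mp4", true);
    await pipeline.UnloadAsync();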
diff --git a/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs b/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs
index dc5d913..5ca8ee7 100644
--- a/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs
+++ b/OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs
@@ -1,4 +1,5 @@
using Microsoft.Extensions.Logging;
+using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core;
using OnnxStack.Core.Config;
using OnnxStack.Core.Image;
@@ -118,31 +119,61 @@ public async IAsyncEnumerable<OnnxImage> RunAsync(IAsyncEnumerable<OnnxImage> im
/// <returns></returns>
private async Task<OnnxImage> RunInternalAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
{
- var controlImage = await inputImage.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, ImageNormalizeType.ZeroToOne);
+ var originalWidth = inputImage.Width;
+ var originalHeight = inputImage.Height;
+ var inputTensor = _featureExtractorModel.SampleSize <= 0
+ ? await inputImage.GetImageTensorAsync(_featureExtractorModel.InputNormalization)
+ : await inputImage.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, _featureExtractorModel.InputNormalization, resizeMode: _featureExtractorModel.InputResizeMode);
var metadata = await _featureExtractorModel.GetMetadataAsync();
cancellationToken.ThrowIfCancellationRequested();
- var outputShape = new[] { 1, _featureExtractorModel.Channels, _featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize };
+ var outputShape = new[] { 1, _featureExtractorModel.OutputChannels, inputTensor.Dimensions[2], inputTensor.Dimensions[3] };
var outputBuffer = metadata.Outputs[0].Value.Dimensions.Length == 4 ? outputShape : outputShape[1..];
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
{
- inferenceParameters.AddInputTensor(controlImage);
+ inferenceParameters.AddInputTensor(inputTensor);
inferenceParameters.AddOutputBuffer(outputBuffer);
- var results = await _featureExtractorModel.RunInferenceAsync(inferenceParameters);
- using (var result = results.First())
+ var inferenceResults = await _featureExtractorModel.RunInferenceAsync(inferenceParameters);
+ using (var inferenceResult = inferenceResults.First())
{
cancellationToken.ThrowIfCancellationRequested();
- var resultTensor = result.ToDenseTensor(outputShape);
- if (_featureExtractorModel.Normalize)
- resultTensor.NormalizeMinMax();
+ var outputTensor = inferenceResult.ToDenseTensor(outputShape);
+ if (_featureExtractorModel.NormalizeOutputTensor)
+ outputTensor.NormalizeMinMax();
- return resultTensor.ToImageMask();
+ var imageResult = default(OnnxImage);
+ if (_featureExtractorModel.SetOutputToInputAlpha)
+ imageResult = new OnnxImage(AddAlphaChannel(inputTensor, outputTensor), _featureExtractorModel.InputNormalization);
+ else if (_featureExtractorModel.OutputChannels >= 3)
+ imageResult = new OnnxImage(outputTensor, _featureExtractorModel.InputNormalization);
+ else
+ imageResult = outputTensor.ToImageMask();
+
+ if (_featureExtractorModel.InputResizeMode == ImageResizeMode.Stretch && (imageResult.Width != originalWidth || imageResult.Height != originalHeight))
+ imageResult.Resize(originalHeight, originalWidth, _featureExtractorModel.InputResizeMode);
+
+ return imageResult;
}
}
}
+ /// <summary>
+ /// Adds an alpha channel to the RGB tensor.
+ /// </summary>
+ /// <param name="sourceImage">The source image.</param>
+ /// <param name="alphaChannel">The alpha channel.</param>
+ /// <returns></returns>
+ private static DenseTensor<float> AddAlphaChannel(DenseTensor<float> sourceImage, DenseTensor<float> alphaChannel)
+ {
+ var resultTensor = new DenseTensor<float>(new int[] { 1, 4, sourceImage.Dimensions[2], sourceImage.Dimensions[3] });
+ sourceImage.Buffer.Span.CopyTo(resultTensor.Buffer[..(int)sourceImage.Length].Span);
+ alphaChannel.Buffer.Span.CopyTo(resultTensor.Buffer[(int)sourceImage.Length..].Span);
+ return resultTensor;
+ }
+
+
/// <summary>
/// Creates the pipeline from a FeatureExtractorModelSet.
/// </summary>
@@ -157,14 +188,19 @@ public static FeatureExtractorPipeline CreatePipeline(FeatureExtractorModelSet m
/// <summary>
- /// Creates the pipeline from the specified file.
+ /// Creates the pipeline from the specified arguments.
/// </summary>
/// <param name="modelFile">The model file.</param>
+ /// <param name="sampleSize">Size of the sample.</param>
+ /// <param name="outputChannels">The channels.</param>
+ /// <param name="normalizeOutputTensor">if set to <c>true</c> [normalize output tensor].</param>
+ /// <param name="normalizeInputTensor">The normalize input tensor.</param>
+ /// <param name="setOutputToInputAlpha">if set to <c>true</c> [set output to input alpha].</param>
/// <param name="deviceId">The device identifier.</param>
/// <param name="executionProvider">The execution provider.</param>
/// <param name="logger">The logger.</param>
/// <returns></returns>
- public static FeatureExtractorPipeline CreatePipeline(string modelFile, bool normalize = false, int sampleSize = 512, int channels = 1, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
+ public static FeatureExtractorPipeline CreatePipeline(string modelFile, int sampleSize = 0, int outputChannels = 1, bool normalizeOutputTensor = false, ImageNormalizeType normalizeInputTensor = ImageNormalizeType.ZeroToOne, ImageResizeMode inputResizeMode = ImageResizeMode.Crop, bool setOutputToInputAlpha = false, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
{
var name = Path.GetFileNameWithoutExtension(modelFile);
var configuration = new FeatureExtractorModelSet
@@ -177,8 +213,11 @@ public static FeatureExtractorPipeline CreatePipeline(string modelFile, bool nor
{
OnnxModelPath = modelFile,
SampleSize = sampleSize,
- Normalize = normalize,
- Channels = channels
+ OutputChannels = outputChannels,
+ NormalizeOutputTensor = normalizeOutputTensor,
+ SetOutputToInputAlpha = setOutputToInputAlpha,
+ NormalizeInputTensor = normalizeInputTensor,
+ InputResizeMode = inputResizeMode
}
};
return CreatePipeline(configuration, logger);
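One behavioural difference worth noting: sampleSize now defaults to 0, which skips the pre-inference resize so the extractor runs at the source resolution. A sketch (model and image paths are illustrative):

    // No sampleSize given (0): RunInternalAsync tensorizes the image at its original dimensions.
    var canny = FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\canny.onnx");
    var edges = await canny.RunAsync(await OnnxImage.FromFileAsync("photo.png"));
    await canny.UnloadAsync();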