Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Image Batch Processing #23

Merged
merged 6 commits into from
Nov 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions OnnxStack.Console/Examples/StableDebug.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Enums;
using OnnxStack.StableDiffusion.Services;
using SixLabors.ImageSharp;
using System.Diagnostics;

Expand Down Expand Up @@ -37,11 +37,11 @@ public async Task RunAsync()
{
Prompt = prompt,
NegativePrompt = negativePrompt,
SchedulerType = SchedulerType.LMS
};

var schedulerOptions = new SchedulerOptions
{
SchedulerType = SchedulerType.LMS,
Seed = 624461087,
//Seed = Random.Shared.Next(),
GuidanceScale = 8,
Expand All @@ -54,9 +54,9 @@ public async Task RunAsync()
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
await _stableDiffusionService.LoadModel(model);

foreach (var schedulerType in Helpers.GetPipelineSchedulers(model.PipelineType))
foreach (var schedulerType in model.PipelineType.GetSchedulerTypes())
{
promptOptions.SchedulerType = schedulerType;
schedulerOptions.SchedulerType = schedulerType;
OutputHelpers.WriteConsole($"Generating {schedulerType} Image...", ConsoleColor.Green);
await GenerateImage(model, promptOptions, schedulerOptions);
}
Expand All @@ -72,12 +72,12 @@ public async Task RunAsync()
private async Task<bool> GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options)
{
var timestamp = Stopwatch.GetTimestamp();
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}.png");
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}.png");
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);
if (result is not null)
{
await result.SaveAsPngAsync(outputFilename);
OutputHelpers.WriteConsole($"{prompt.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
OutputHelpers.WriteConsole($"{options.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
OutputHelpers.WriteConsole($"Elapsed: {Stopwatch.GetElapsedTime(timestamp)}ms", ConsoleColor.Yellow);
return true;
}
Expand Down
62 changes: 22 additions & 40 deletions OnnxStack.Console/Examples/StableDiffusionBatch.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
using OnnxStack.Core;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Enums;
using OnnxStack.StableDiffusion.Helpers;
using OnnxStack.StableDiffusion;
using SixLabors.ImageSharp;
using OnnxStack.StableDiffusion.Helpers;

namespace OnnxStack.Console.Runner
{
Expand Down Expand Up @@ -31,68 +31,50 @@ public async Task RunAsync()

while (true)
{
OutputHelpers.WriteConsole("Please type a prompt and press ENTER", ConsoleColor.Yellow);
var prompt = OutputHelpers.ReadConsole(ConsoleColor.Cyan);

OutputHelpers.WriteConsole("Please type a negative prompt and press ENTER (optional)", ConsoleColor.Yellow);
var negativePrompt = OutputHelpers.ReadConsole(ConsoleColor.Cyan);

OutputHelpers.WriteConsole("Please enter a batch count and press ENTER", ConsoleColor.Yellow);
var batch = OutputHelpers.ReadConsole(ConsoleColor.Cyan);
int.TryParse(batch, out var batchCount);
batchCount = Math.Max(1, batchCount);

var promptOptions = new PromptOptions
{
Prompt = prompt,
NegativePrompt = negativePrompt,
BatchCount = batchCount
Prompt = "Photo of a cat"
};

var schedulerOptions = new SchedulerOptions
{
Seed = Random.Shared.Next(),

GuidanceScale = 8,
InferenceSteps = 22,
InferenceSteps = 20,
Strength = 0.6f
};

var batchOptions = new BatchOptions
{
BatchType = BatchOptionType.Scheduler
};

foreach (var model in _stableDiffusionService.Models)
{
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
await _stableDiffusionService.LoadModel(model);

foreach (var schedulerType in Helpers.GetPipelineSchedulers(model.PipelineType))
var batchIndex = 0;
var callback = (int batch, int batchCount, int step, int steps) =>
{
batchIndex = batch;
OutputHelpers.WriteConsole($"Image: {batch}/{batchCount} - Step: {step}/{steps}", ConsoleColor.Cyan);
};

await foreach (var result in _stableDiffusionService.GenerateBatchAsync(model, promptOptions, schedulerOptions, batchOptions, callback))
{
promptOptions.SchedulerType = schedulerType;
OutputHelpers.WriteConsole($"Generating {schedulerType} Image...", ConsoleColor.Green);
await GenerateImage(model, promptOptions, schedulerOptions);
var outputFilename = Path.Combine(_outputDirectory, $"{batchIndex}_{result.SchedulerOptions.Seed}.png");
var image = result.ImageResult.ToImage();
await image.SaveAsPngAsync(outputFilename);
OutputHelpers.WriteConsole($"Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
}

OutputHelpers.WriteConsole($"Unloading Model `{model.Name}`...", ConsoleColor.Green);
await _stableDiffusionService.UnloadModel(model);
}
}
}

private async Task<bool> GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options)
{

var result = await _stableDiffusionService.GenerateAsync(model, prompt, options);
if (result == null)
return false;

var imageTensors = result.Split(prompt.BatchCount);
for (int i = 0; i < imageTensors.Length; i++)
{
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}_{i}.png");
var image = imageTensors[i].ToImage();
await image.SaveAsPngAsync(outputFilename);
OutputHelpers.WriteConsole($"{prompt.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
}

return true;
}
}
}
11 changes: 5 additions & 6 deletions OnnxStack.Console/Examples/StableDiffusionExample.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using OnnxStack.Core;
using OnnxStack.StableDiffusion;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Enums;
using SixLabors.ImageSharp;

namespace OnnxStack.Console.Runner
Expand Down Expand Up @@ -53,9 +52,9 @@ public async Task RunAsync()
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
await _stableDiffusionService.LoadModel(model);

foreach (var schedulerType in Helpers.GetPipelineSchedulers(model.PipelineType))
foreach (var schedulerType in model.PipelineType.GetSchedulerTypes())
{
promptOptions.SchedulerType = schedulerType;
schedulerOptions.SchedulerType = schedulerType;
OutputHelpers.WriteConsole($"Generating {schedulerType} Image...", ConsoleColor.Green);
await GenerateImage(model, promptOptions, schedulerOptions);
}
Expand All @@ -68,13 +67,13 @@ public async Task RunAsync()

private async Task<bool> GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options)
{
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}.png");
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}.png");
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);
if (result == null)
return false;

await result.SaveAsPngAsync(outputFilename);
OutputHelpers.WriteConsole($"{prompt.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
OutputHelpers.WriteConsole($"{options.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
return true;
}
}
Expand Down
11 changes: 5 additions & 6 deletions OnnxStack.Console/Examples/StableDiffusionGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using OnnxStack.Core;
using OnnxStack.StableDiffusion;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Enums;
using SixLabors.ImageSharp;
using System.Collections.ObjectModel;

Expand Down Expand Up @@ -48,9 +47,9 @@ public async Task RunAsync()
{
Seed = Random.Shared.Next()
};
foreach (var schedulerType in Helpers.GetPipelineSchedulers(model.PipelineType))
foreach (var schedulerType in model.PipelineType.GetSchedulerTypes())
{
promptOptions.SchedulerType = schedulerType;
schedulerOptions.SchedulerType = schedulerType;
OutputHelpers.WriteConsole($"Generating {schedulerType} Image...", ConsoleColor.Green);
await GenerateImage(model, promptOptions, schedulerOptions, generationPrompt.Key);
}
Expand All @@ -65,13 +64,13 @@ public async Task RunAsync()

private async Task<bool> GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options, string key)
{
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}_{key}.png");
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}_{key}.png");
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);
if (result == null)
return false;

await result.SaveAsPngAsync(outputFilename);
OutputHelpers.WriteConsole($"{prompt.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
OutputHelpers.WriteConsole($"{options.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
return true;
}

Expand Down
28 changes: 0 additions & 28 deletions OnnxStack.Console/Helpers.cs

This file was deleted.

49 changes: 49 additions & 0 deletions OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using OnnxStack.Core.Config;
using OnnxStack.Core.Model;
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Models;
using SixLabors.ImageSharp;
using SixLabors.ImageSharp.PixelFormats;
using System;
Expand Down Expand Up @@ -83,5 +84,53 @@ public interface IStableDiffusionService
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The diffusion result as <see cref="System.IO.Stream"/></returns>
Task<Stream> GenerateAsStreamAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);

/// <summary>
/// Generates a batch of StableDiffusion image using the prompt and options provided.
/// </summary>
/// <param name="modelOptions">The model options.</param>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
/// <param name="batchOptions">The batch options.</param>
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
IAsyncEnumerable<BatchResult> GenerateBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);

/// <summary>
/// Generates a batch of StableDiffusion image using the prompt and options provided.
/// </summary>
/// <param name="modelOptions">The model options.</param>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
/// <param name="batchOptions">The batch options.</param>
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
IAsyncEnumerable<Image<Rgba32>> GenerateBatchAsImageAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);

/// <summary>
/// Generates a batch of StableDiffusion image using the prompt and options provided.
/// </summary>
/// <param name="modelOptions">The model options.</param>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
/// <param name="batchOptions">The batch options.</param>
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
IAsyncEnumerable<byte[]> GenerateBatchAsBytesAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);

/// <summary>
/// Generates a batch of StableDiffusion image using the prompt and options provided.
/// </summary>
/// <param name="modelOptions">The model options.</param>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
/// <param name="batchOptions">The batch options.</param>
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
IAsyncEnumerable<Stream> GenerateBatchAsStreamAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
}
}
12 changes: 12 additions & 0 deletions OnnxStack.StableDiffusion/Config/BatchOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using OnnxStack.StableDiffusion.Enums;

namespace OnnxStack.StableDiffusion.Config
{
public record BatchOptions
{
public BatchOptionType BatchType { get; set; }
public float ValueTo { get; set; }
public float ValueFrom { get; set; }
public float Increment { get; set; } = 1f;
}
}
1 change: 0 additions & 1 deletion OnnxStack.StableDiffusion/Config/PromptOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ public class PromptOptions

[StringLength(512)]
public string NegativePrompt { get; set; }
public SchedulerType SchedulerType { get; set; }

public int BatchCount { get; set; } = 1;

Expand Down
7 changes: 6 additions & 1 deletion OnnxStack.StableDiffusion/Config/SchedulerOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@

namespace OnnxStack.StableDiffusion.Config
{
public class SchedulerOptions
public record SchedulerOptions
{
/// <summary>
/// Gets or sets the type of scheduler.
/// </summary>
public SchedulerType SchedulerType { get; set; }

/// <summary>
/// Gets or sets the height.
/// </summary>
Expand Down
15 changes: 15 additions & 0 deletions OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Enums;
using OnnxStack.StableDiffusion.Models;
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -33,5 +35,18 @@ public interface IDiffuser
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);


/// <summary>
/// Runs the stable diffusion batch loop
/// </summary>
/// <param name="modelOptions">The model options.</param>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
/// <param name="batchOptions">The batch options.</param>
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
IAsyncEnumerable<BatchResult> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
}
}
Loading