Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit ba6475f

Browse files
authored
Merge pull request #119 from saddam213/API
Refactor StableDiffusionPipeline
2 parents c222d4d + d1922ad commit ba6475f

File tree

5 files changed

+103
-124
lines changed

5 files changed

+103
-124
lines changed

OnnxStack.Core/Video/OnnxVideo.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,20 @@ public OnnxVideo(VideoInfo info, DenseTensor<float> videoTensor)
4141
}
4242

4343

44+
/// <summary>
45+
/// Initializes a new instance of the <see cref="OnnxVideo"/> class.
46+
/// </summary>
47+
/// <param name="info">The information.</param>
48+
/// <param name="videoTensors">The video tensors.</param>
49+
public OnnxVideo(VideoInfo info, IEnumerable<DenseTensor<float>> videoTensors)
50+
{
51+
_info = info;
52+
_frames = videoTensors
53+
.Select(x => new OnnxImage(x))
54+
.ToList();
55+
}
56+
57+
4458
/// <summary>
4559
/// Gets the height.
4660
/// </summary>
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
2+
using OnnxStack.Core.Image;
3+
using OnnxStack.Core.Video;
24
using OnnxStack.StableDiffusion.Config;
35

46
namespace OnnxStack.StableDiffusion.Common
57
{
68
public record BatchResult(SchedulerOptions SchedulerOptions, DenseTensor<float> Result);
9+
public record BatchImageResult(SchedulerOptions SchedulerOptions, OnnxImage Result);
10+
public record BatchVideoResult(SchedulerOptions SchedulerOptions, OnnxVideo Result);
711
}

OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ public interface IPipeline
101101
/// <param name="progressCallback">The progress callback.</param>
102102
/// <param name="cancellationToken">The cancellation token.</param>
103103
/// <returns></returns>
104-
IAsyncEnumerable<OnnxImage> GenerateImageBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
104+
IAsyncEnumerable<BatchImageResult> GenerateImageBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
105105

106106

107107
/// <summary>
@@ -126,6 +126,6 @@ public interface IPipeline
126126
/// <param name="progressCallback">The progress callback.</param>
127127
/// <param name="cancellationToken">The cancellation token.</param>
128128
/// <returns></returns>
129-
IAsyncEnumerable<OnnxVideo> GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
129+
IAsyncEnumerable<BatchVideoResult> GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
130130
}
131131
}

OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ protected PipelineBase(PipelineOptions pipelineOptions, ILogger logger)
133133
/// <param name="progressCallback">The progress callback.</param>
134134
/// <param name="cancellationToken">The cancellation token.</param>
135135
/// <returns></returns>
136-
public abstract IAsyncEnumerable<OnnxImage> GenerateImageBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
136+
public abstract IAsyncEnumerable<BatchImageResult> GenerateImageBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
137137

138138

139139
/// <summary>
@@ -158,7 +158,7 @@ protected PipelineBase(PipelineOptions pipelineOptions, ILogger logger)
158158
/// <param name="progressCallback">The progress callback.</param>
159159
/// <param name="cancellationToken">The cancellation token.</param>
160160
/// <returns></returns>
161-
public abstract IAsyncEnumerable<OnnxVideo> GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
161+
public abstract IAsyncEnumerable<BatchVideoResult> GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
162162

163163

164164
/// <summary>

0 commit comments

Comments
 (0)