Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit d1922ad

Browse files
committed
Tidy up StableDiffusionPipeline codebase
1 parent 4f0ae87 commit d1922ad

File tree

2 files changed

+94
-119
lines changed

2 files changed

+94
-119
lines changed

OnnxStack.Core/Video/OnnxVideo.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,20 @@ public OnnxVideo(VideoInfo info, DenseTensor<float> videoTensor)
4141
}
4242

4343

44+
/// <summary>
45+
/// Initializes a new instance of the <see cref="OnnxVideo"/> class.
46+
/// </summary>
47+
/// <param name="info">The information.</param>
48+
/// <param name="videoTensors">The video tensors.</param>
49+
public OnnxVideo(VideoInfo info, IEnumerable<DenseTensor<float>> videoTensors)
50+
{
51+
_info = info;
52+
_frames = videoTensors
53+
.Select(x => new OnnxImage(x))
54+
.ToList();
55+
}
56+
57+
4458
/// <summary>
4559
/// Gets the height.
4660
/// </summary>

OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs

Lines changed: 80 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ public class StableDiffusionPipeline : PipelineBase
3434
protected IReadOnlyList<SchedulerType> _supportedSchedulers;
3535
protected SchedulerOptions _defaultSchedulerOptions;
3636

37+
protected sealed record BatchResultInternal(SchedulerOptions SchedulerOptions, List<DenseTensor<float>> Result);
38+
3739
/// <summary>
3840
/// Initializes a new instance of the <see cref="StableDiffusionPipeline"/> class.
3941
/// </summary>
@@ -165,35 +167,10 @@ public override void ValidateInputs(PromptOptions promptOptions, SchedulerOption
165167
/// <returns></returns>
166168
public override async Task<DenseTensor<float>> RunAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
167169
{
168-
var diffuseTime = _logger?.LogBegin("Diffuser starting...");
169-
var options = GetSchedulerOptionsOrDefault(schedulerOptions);
170-
_logger?.Log($"Model: {Name}, Pipeline: {PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {options.SchedulerType}");
171-
172-
// Check guidance
173-
var performGuidance = ShouldPerformGuidance(options);
174-
175-
// Process prompts
176-
var promptEmbeddings = await CreatePromptEmbedsAsync(promptOptions, performGuidance);
177-
178-
// Create Diffuser
179-
var diffuser = CreateDiffuser(promptOptions.DiffuserType, controlNet);
180-
181-
// Diffuse
182-
var tensorResult = default(DenseTensor<float>);
183-
if (promptOptions.HasInputVideo)
184-
{
185-
await foreach (var frameTensor in DiffuseVideoAsync(diffuser, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken))
186-
{
187-
tensorResult = tensorResult.Concatenate(frameTensor);
188-
}
189-
}
190-
else
191-
{
192-
tensorResult = await DiffuseImageAsync(diffuser, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
193-
}
194-
195-
_logger?.LogEnd($"Diffuser complete", diffuseTime);
196-
return tensorResult;
170+
var tensors = await RunInternalAsync(promptOptions, schedulerOptions, controlNet, progressCallback, cancellationToken);
171+
return tensors.Count == 1
172+
? tensors.First() // ImageTensor
173+
: tensors.Join(); // VideoTensor
197174
}
198175

199176

@@ -209,45 +186,13 @@ public override async Task<DenseTensor<float>> RunAsync(PromptOptions promptOpti
209186
/// <returns></returns>
210187
public override async IAsyncEnumerable<BatchResult> RunBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
211188
{
212-
var diffuseBatchTime = _logger?.LogBegin("Batch Diffuser starting...");
213-
var options = GetSchedulerOptionsOrDefault(schedulerOptions);
214-
_logger?.Log($"Model: {Name}, Pipeline: {PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {options.SchedulerType}");
215-
_logger?.Log($"BatchType: {batchOptions.BatchType}, ValueFrom: {batchOptions.ValueFrom}, ValueTo: {batchOptions.ValueTo}, Increment: {batchOptions.Increment}");
216-
217-
// Check guidance
218-
var performGuidance = ShouldPerformGuidance(options);
219-
220-
// Process prompts
221-
var promptEmbeddings = await CreatePromptEmbedsAsync(promptOptions, performGuidance);
222-
223-
// Generate batch options
224-
var batchSchedulerOptions = BatchGenerator.GenerateBatch(this, batchOptions, options);
225-
226-
// Create Diffuser
227-
var diffuser = CreateDiffuser(promptOptions.DiffuserType, controlNet);
228-
229-
// Diffuse
230-
var batchIndex = 1;
231-
var batchSchedulerCallback = CreateBatchCallback(progressCallback, batchSchedulerOptions.Count, () => batchIndex);
232-
foreach (var batchSchedulerOption in batchSchedulerOptions)
189+
await foreach (var batchResult in RunBatchInternalAsync(batchOptions, promptOptions, schedulerOptions, controlNet, progressCallback, cancellationToken))
233190
{
234-
var tensorResult = default(DenseTensor<float>);
235-
if (promptOptions.HasInputVideo)
236-
{
237-
await foreach (var frameTensor in DiffuseVideoAsync(diffuser, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, progressCallback, cancellationToken))
238-
{
239-
tensorResult = tensorResult.Concatenate(frameTensor);
240-
}
241-
}
242-
else
243-
{
244-
tensorResult = await DiffuseImageAsync(diffuser, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
245-
}
246-
yield return new BatchResult(batchSchedulerOption, tensorResult);
247-
batchIndex++;
191+
var tensor = batchResult.Result.Count == 1
192+
? batchResult.Result.First() // ImageTensor
193+
: batchResult.Result.Join(); // VideoTensor
194+
yield return new BatchResult(batchResult.SchedulerOptions, tensor);
248195
}
249-
250-
_logger?.LogEnd($"Batch Diffuser complete", diffuseBatchTime);
251196
}
252197

253198

@@ -262,22 +207,8 @@ public override async IAsyncEnumerable<BatchResult> RunBatchAsync(BatchOptions b
262207
/// <returns></returns>
263208
public override async Task<OnnxImage> GenerateImageAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
264209
{
265-
var diffuseTime = _logger?.LogBegin("Diffuser starting...");
266-
var options = GetSchedulerOptionsOrDefault(schedulerOptions);
267-
_logger?.Log($"Model: {Name}, Pipeline: {PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {options.SchedulerType}");
268-
269-
// Check guidance
270-
var performGuidance = ShouldPerformGuidance(options);
271-
272-
// Process prompts
273-
var promptEmbeddings = await CreatePromptEmbedsAsync(promptOptions, performGuidance);
274-
275-
// Create Diffuser
276-
var diffuser = CreateDiffuser(promptOptions.DiffuserType, controlNet);
277-
278-
var imageResult = await DiffuseImageAsync(diffuser, promptOptions, options, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
279-
280-
return new OnnxImage(imageResult);
210+
var tensors = await RunInternalAsync(promptOptions, schedulerOptions, controlNet, progressCallback, cancellationToken);
211+
return new OnnxImage(tensors.First());
281212
}
282213

283214

@@ -293,47 +224,58 @@ public override async Task<OnnxImage> GenerateImageAsync(PromptOptions promptOpt
293224
/// <returns></returns>
294225
public override async IAsyncEnumerable<BatchImageResult> GenerateImageBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
295226
{
296-
var diffuseBatchTime = _logger?.LogBegin("Batch Diffuser starting...");
297-
var options = GetSchedulerOptionsOrDefault(schedulerOptions);
298-
_logger?.Log($"Model: {Name}, Pipeline: {PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {options.SchedulerType}");
299-
_logger?.Log($"BatchType: {batchOptions.BatchType}, ValueFrom: {batchOptions.ValueFrom}, ValueTo: {batchOptions.ValueTo}, Increment: {batchOptions.Increment}");
300-
301-
// Check guidance
302-
var performGuidance = ShouldPerformGuidance(options);
227+
await foreach (var batchResult in RunBatchInternalAsync(batchOptions, promptOptions, schedulerOptions, controlNet, progressCallback, cancellationToken))
228+
{
229+
yield return new BatchImageResult(batchResult.SchedulerOptions, new OnnxImage(batchResult.Result.First()));
230+
}
231+
}
303232

304-
// Process prompts
305-
var promptEmbeddings = await CreatePromptEmbedsAsync(promptOptions, performGuidance);
306233

307-
// Generate batch options
308-
var batchSchedulerOptions = BatchGenerator.GenerateBatch(this, batchOptions, options);
234+
/// <summary>
235+
/// Runs the pipeline returning the result as an OnnxVideo.
236+
/// </summary>
237+
/// <param name="promptOptions">The prompt options.</param>
238+
/// <param name="schedulerOptions">The scheduler options.</param>
239+
/// <param name="controlNet">The control net.</param>
240+
/// <param name="progressCallback">The progress callback.</param>
241+
/// <param name="cancellationToken">The cancellation token.</param>
242+
/// <returns></returns>
243+
public override async Task<OnnxVideo> GenerateVideoAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
244+
{
245+
var tensors = await RunInternalAsync(promptOptions, schedulerOptions, controlNet, progressCallback, cancellationToken);
246+
return new OnnxVideo(promptOptions.InputVideo.Info, tensors);
247+
}
309248

310-
// Create Diffuser
311-
var diffuser = CreateDiffuser(promptOptions.DiffuserType, controlNet);
312249

313-
// Diffuse
314-
var batchIndex = 1;
315-
var batchSchedulerCallback = CreateBatchCallback(progressCallback, batchSchedulerOptions.Count, () => batchIndex);
316-
foreach (var batchSchedulerOption in batchSchedulerOptions)
250+
/// <summary>
251+
/// Runs the batch pipeline returning the result as an OnnxVideo.
252+
/// </summary>
253+
/// <param name="batchOptions">The batch options.</param>
254+
/// <param name="promptOptions">The prompt options.</param>
255+
/// <param name="schedulerOptions">The scheduler options.</param>
256+
/// <param name="controlNet">The control net.</param>
257+
/// <param name="progressCallback">The progress callback.</param>
258+
/// <param name="cancellationToken">The cancellation token.</param>
259+
/// <returns></returns>
260+
public override async IAsyncEnumerable<BatchVideoResult> GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
261+
{
262+
await foreach (var batchResult in RunBatchInternalAsync(batchOptions, promptOptions, schedulerOptions, controlNet, progressCallback, cancellationToken))
317263
{
318-
var tensorResult = await DiffuseImageAsync(diffuser, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
319-
yield return new BatchImageResult(batchSchedulerOption, new OnnxImage(tensorResult));
320-
batchIndex++;
264+
yield return new BatchVideoResult(batchResult.SchedulerOptions, new OnnxVideo(promptOptions.InputVideo.Info, batchResult.Result));
321265
}
322-
323-
_logger?.LogEnd($"Batch Diffuser complete", diffuseBatchTime);
324266
}
325267

326268

327269
/// <summary>
328-
/// Runs the pipeline returning the result as an OnnxVideo.
270+
/// Runs the pipeline
329271
/// </summary>
330272
/// <param name="promptOptions">The prompt options.</param>
331273
/// <param name="schedulerOptions">The scheduler options.</param>
332274
/// <param name="controlNet">The control net.</param>
333275
/// <param name="progressCallback">The progress callback.</param>
334276
/// <param name="cancellationToken">The cancellation token.</param>
335277
/// <returns></returns>
336-
public override async Task<OnnxVideo> GenerateVideoAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
278+
protected virtual async Task<List<DenseTensor<float>>> RunInternalAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
337279
{
338280
var diffuseTime = _logger?.LogBegin("Diffuser starting...");
339281
var options = GetSchedulerOptionsOrDefault(schedulerOptions);
@@ -348,17 +290,30 @@ public override async Task<OnnxVideo> GenerateVideoAsync(PromptOptions promptOpt
348290
// Create Diffuser
349291
var diffuser = CreateDiffuser(promptOptions.DiffuserType, controlNet);
350292

351-
var frames = new List<OnnxImage>();
352-
await foreach (var frameTensor in DiffuseVideoAsync(diffuser, promptOptions, options, promptEmbeddings, performGuidance, progressCallback, cancellationToken))
293+
// Diffuse
294+
var tensorResult = new List<DenseTensor<float>>();
295+
if (promptOptions.HasInputVideo)
353296
{
354-
frames.Add(new OnnxImage(frameTensor));
297+
var frameIndex = 1;
298+
var frameSchedulerCallback = CreateBatchCallback(progressCallback, promptOptions.InputVideo.Frames.Count, () => frameIndex);
299+
await foreach (var frameTensor in DiffuseVideoAsync(diffuser, promptOptions, options, promptEmbeddings, performGuidance, frameSchedulerCallback, cancellationToken))
300+
{
301+
frameIndex++;
302+
tensorResult.Add(frameTensor);
303+
}
355304
}
356-
return new OnnxVideo(promptOptions.InputVideo.Info, frames);
305+
else
306+
{
307+
tensorResult.Add(await DiffuseImageAsync(diffuser, promptOptions, options, promptEmbeddings, performGuidance, progressCallback, cancellationToken));
308+
}
309+
310+
_logger?.LogEnd($"Diffuser complete", diffuseTime);
311+
return tensorResult;
357312
}
358313

359314

360315
/// <summary>
361-
/// Runs the batch pipeline returning the result as an OnnxVideo.
316+
/// Runs the pipeline batch.
362317
/// </summary>
363318
/// <param name="batchOptions">The batch options.</param>
364319
/// <param name="promptOptions">The prompt options.</param>
@@ -367,7 +322,7 @@ public override async Task<OnnxVideo> GenerateVideoAsync(PromptOptions promptOpt
367322
/// <param name="progressCallback">The progress callback.</param>
368323
/// <param name="cancellationToken">The cancellation token.</param>
369324
/// <returns></returns>
370-
public override async IAsyncEnumerable<BatchVideoResult> GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
325+
protected virtual async IAsyncEnumerable<BatchResultInternal> RunBatchInternalAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
371326
{
372327
var diffuseBatchTime = _logger?.LogBegin("Batch Diffuser starting...");
373328
var options = GetSchedulerOptionsOrDefault(schedulerOptions);
@@ -387,19 +342,26 @@ public override async IAsyncEnumerable<BatchVideoResult> GenerateVideoBatchAsync
387342
var diffuser = CreateDiffuser(promptOptions.DiffuserType, controlNet);
388343

389344
// Diffuse
390-
var batchIndex = 1;
345+
var batchIndex = 1;// TODO: Video batch callback shoud be (BatchIndex + FrameIndex), not (BatchIndex + StepIndex)
391346
var batchSchedulerCallback = CreateBatchCallback(progressCallback, batchSchedulerOptions.Count, () => batchIndex);
392347
foreach (var batchSchedulerOption in batchSchedulerOptions)
393348
{
394-
var frames = new List<OnnxImage>();
395-
await foreach (var frameTensor in DiffuseVideoAsync(diffuser, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, progressCallback, cancellationToken))
349+
var tensorResult = new List<DenseTensor<float>>();
350+
if (promptOptions.HasInputVideo)
396351
{
397-
frames.Add(new OnnxImage(frameTensor));
352+
await foreach (var frameTensor in DiffuseVideoAsync(diffuser, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, batchSchedulerCallback, cancellationToken))
353+
{
354+
tensorResult.Add(frameTensor);
355+
}
356+
}
357+
else
358+
{
359+
tensorResult.Add(await DiffuseImageAsync(diffuser, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, batchSchedulerCallback, cancellationToken));
398360
}
399-
yield return new BatchVideoResult(batchSchedulerOption, new OnnxVideo(promptOptions.InputVideo.Info, frames));
361+
400362
batchIndex++;
363+
yield return new BatchResultInternal(batchSchedulerOption, tensorResult);
401364
}
402-
403365
_logger?.LogEnd($"Batch Diffuser complete", diffuseBatchTime);
404366
}
405367

@@ -623,5 +585,4 @@ public static StableDiffusionPipeline CreatePipeline(string modelFolder, ModelTy
623585
return CreatePipeline(ModelFactory.CreateModelSet(modelFolder, DiffuserPipelineType.StableDiffusion, modelType, deviceId, executionProvider, memoryMode), logger);
624586
}
625587
}
626-
627588
}

0 commit comments

Comments
 (0)