Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit eaec794

Browse files
authored
Merge pull request #95 from saddam213/BatchVideo
Batch Video Processing
2 parents fee8c5d + ec00c32 commit eaec794

File tree

1 file changed

+71
-32
lines changed

1 file changed

+71
-32
lines changed

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 71 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(ModelOptions modelOpt
105105
// Create random seed if none was set
106106
schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next();
107107

108-
var diffuseTime = _logger?.LogBegin("Diffuse starting...");
108+
var diffuseTime = _logger?.LogBegin("Diffuser starting...");
109109
_logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {schedulerOptions.SchedulerType}");
110110

111111
// Check guidance
@@ -114,36 +114,15 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(ModelOptions modelOpt
114114
// Process prompts
115115
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions.BaseModel, promptOptions, performGuidance);
116116

117-
// If video input, process frames
118-
if (promptOptions.HasInputVideo)
119-
{
120-
var frameIndex = 0;
121-
DenseTensor<float> videoTensor = null;
122-
var videoFrames = promptOptions.InputVideo.VideoFrames.Frames;
123-
var schedulerFrameCallback = CreateBatchCallback(progressCallback, videoFrames.Count, () => frameIndex);
124-
foreach (var videoFrame in videoFrames)
125-
{
126-
frameIndex++;
127-
promptOptions.InputImage = promptOptions.DiffuserType == DiffuserType.ControlNet ? default : new InputImage(videoFrame);
128-
promptOptions.InputContolImage = promptOptions.DiffuserType == DiffuserType.ImageToImage ? default : new InputImage(videoFrame);
129-
var frameResultTensor = await SchedulerStepAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, schedulerFrameCallback, cancellationToken);
130-
131-
// Frame Progress
132-
ReportBatchProgress(progressCallback, frameIndex, videoFrames.Count, frameResultTensor);
117+
var tensorResult = promptOptions.HasInputVideo
118+
? await DiffuseVideoAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken)
119+
: await DiffuseImageAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
133120

134-
// Concatenate frame
135-
videoTensor = videoTensor.Concatenate(frameResultTensor);
136-
}
121+
_logger?.LogEnd($"Diffuser complete", diffuseTime);
122+
return tensorResult;
123+
}
137124

138-
_logger?.LogEnd($"Diffuse complete", diffuseTime);
139-
return videoTensor;
140-
}
141125

142-
// Run Scheduler steps
143-
var schedulerResult = await SchedulerStepAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
144-
_logger?.LogEnd($"Diffuse complete", diffuseTime);
145-
return schedulerResult;
146-
}
147126

148127

149128

@@ -180,13 +159,73 @@ public virtual async IAsyncEnumerable<BatchResult> DiffuseBatchAsync(ModelOption
180159
var batchSchedulerCallback = CreateBatchCallback(progressCallback, batchSchedulerOptions.Count, () => batchIndex);
181160
foreach (var batchSchedulerOption in batchSchedulerOptions)
182161
{
183-
var diffuseTime = _logger?.LogBegin("Diffuse starting...");
184-
yield return new BatchResult(batchSchedulerOption, await SchedulerStepAsync(modelOptions, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, batchSchedulerCallback, cancellationToken));
185-
_logger?.LogEnd($"Diffuse complete", diffuseTime);
162+
var tensorResult = promptOptions.HasInputVideo
163+
? await DiffuseVideoAsync(modelOptions, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, progressCallback, cancellationToken)
164+
: await DiffuseImageAsync(modelOptions, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, batchSchedulerCallback, cancellationToken);
165+
166+
yield return new BatchResult(batchSchedulerOption, tensorResult);
186167
batchIndex++;
187168
}
188169

189-
_logger?.LogEnd($"Diffuse batch complete", diffuseBatchTime);
170+
_logger?.LogEnd($"Batch Diffuser complete", diffuseBatchTime);
171+
}
172+
173+
174+
/// <summary>
175+
/// Diffuses the image.
176+
/// </summary>
177+
/// <param name="modelOptions">The model options.</param>
178+
/// <param name="promptOptions">The prompt options.</param>
179+
/// <param name="schedulerOptions">The scheduler options.</param>
180+
/// <param name="promptEmbeddings">The prompt embeddings.</param>
181+
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
182+
/// <param name="progressCallback">The progress callback.</param>
183+
/// <param name="cancellationToken">The cancellation token.</param>
184+
/// <returns></returns>
185+
protected virtual async Task<DenseTensor<float>> DiffuseImageAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
186+
{
187+
var diffuseTime = _logger?.LogBegin("Image Diffuser starting...");
188+
var schedulerResult = await SchedulerStepAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
189+
_logger?.LogEnd($"Image Diffuser complete", diffuseTime);
190+
return schedulerResult;
191+
}
192+
193+
194+
/// <summary>
195+
/// Diffuses the video.
196+
/// </summary>
197+
/// <param name="modelOptions">The model options.</param>
198+
/// <param name="promptOptions">The prompt options.</param>
199+
/// <param name="schedulerOptions">The scheduler options.</param>
200+
/// <param name="promptEmbeddings">The prompt embeddings.</param>
201+
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
202+
/// <param name="progressCallback">The progress callback.</param>
203+
/// <param name="cancellationToken">The cancellation token.</param>
204+
/// <returns></returns>
205+
protected virtual async Task<DenseTensor<float>> DiffuseVideoAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
206+
{
207+
var diffuseTime = _logger?.LogBegin("Video Diffuser starting...");
208+
209+
var frameIndex = 0;
210+
DenseTensor<float> videoTensor = null;
211+
var videoFrames = promptOptions.InputVideo.VideoFrames.Frames;
212+
var schedulerFrameCallback = CreateBatchCallback(progressCallback, videoFrames.Count, () => frameIndex);
213+
foreach (var videoFrame in videoFrames)
214+
{
215+
frameIndex++;
216+
promptOptions.InputImage = promptOptions.DiffuserType == DiffuserType.ControlNet ? default : new InputImage(videoFrame);
217+
promptOptions.InputContolImage = promptOptions.DiffuserType == DiffuserType.ImageToImage ? default : new InputImage(videoFrame);
218+
var frameResultTensor = await SchedulerStepAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, schedulerFrameCallback, cancellationToken);
219+
220+
// Frame Progress
221+
ReportBatchProgress(progressCallback, frameIndex, videoFrames.Count, frameResultTensor);
222+
223+
// Concatenate frame
224+
videoTensor = videoTensor.Concatenate(frameResultTensor);
225+
}
226+
227+
_logger?.LogEnd($"Video Diffuser complete", diffuseTime);
228+
return videoTensor;
190229
}
191230

192231

0 commit comments

Comments
 (0)