@@ -34,6 +34,8 @@ public class StableDiffusionPipeline : PipelineBase
34
34
protected IReadOnlyList < SchedulerType > _supportedSchedulers ;
35
35
protected SchedulerOptions _defaultSchedulerOptions ;
36
36
37
+ protected sealed record BatchResultInternal ( SchedulerOptions SchedulerOptions , List < DenseTensor < float > > Result ) ;
38
+
37
39
/// <summary>
38
40
/// Initializes a new instance of the <see cref="StableDiffusionPipeline"/> class.
39
41
/// </summary>
@@ -165,35 +167,10 @@ public override void ValidateInputs(PromptOptions promptOptions, SchedulerOption
165
167
/// <returns></returns>
166
168
public override async Task < DenseTensor < float > > RunAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
167
169
{
168
- var diffuseTime = _logger ? . LogBegin ( "Diffuser starting..." ) ;
169
- var options = GetSchedulerOptionsOrDefault ( schedulerOptions ) ;
170
- _logger ? . Log ( $ "Model: { Name } , Pipeline: { PipelineType } , Diffuser: { promptOptions . DiffuserType } , Scheduler: { options . SchedulerType } ") ;
171
-
172
- // Check guidance
173
- var performGuidance = ShouldPerformGuidance ( options ) ;
174
-
175
- // Process prompts
176
- var promptEmbeddings = await CreatePromptEmbedsAsync ( promptOptions , performGuidance ) ;
177
-
178
- // Create Diffuser
179
- var diffuser = CreateDiffuser ( promptOptions . DiffuserType , controlNet ) ;
180
-
181
- // Diffuse
182
- var tensorResult = default ( DenseTensor < float > ) ;
183
- if ( promptOptions . HasInputVideo )
184
- {
185
- await foreach ( var frameTensor in DiffuseVideoAsync ( diffuser , promptOptions , schedulerOptions , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) )
186
- {
187
- tensorResult = tensorResult . Concatenate ( frameTensor ) ;
188
- }
189
- }
190
- else
191
- {
192
- tensorResult = await DiffuseImageAsync ( diffuser , promptOptions , schedulerOptions , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ;
193
- }
194
-
195
- _logger ? . LogEnd ( $ "Diffuser complete", diffuseTime ) ;
196
- return tensorResult ;
170
+ var tensors = await RunInternalAsync ( promptOptions , schedulerOptions , controlNet , progressCallback , cancellationToken ) ;
171
+ return tensors . Count == 1
172
+ ? tensors . First ( ) // ImageTensor
173
+ : tensors . Join ( ) ; // VideoTensor
197
174
}
198
175
199
176
@@ -209,45 +186,13 @@ public override async Task<DenseTensor<float>> RunAsync(PromptOptions promptOpti
209
186
/// <returns></returns>
210
187
public override async IAsyncEnumerable < BatchResult > RunBatchAsync ( BatchOptions batchOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
211
188
{
212
- var diffuseBatchTime = _logger ? . LogBegin ( "Batch Diffuser starting..." ) ;
213
- var options = GetSchedulerOptionsOrDefault ( schedulerOptions ) ;
214
- _logger ? . Log ( $ "Model: { Name } , Pipeline: { PipelineType } , Diffuser: { promptOptions . DiffuserType } , Scheduler: { options . SchedulerType } ") ;
215
- _logger ? . Log ( $ "BatchType: { batchOptions . BatchType } , ValueFrom: { batchOptions . ValueFrom } , ValueTo: { batchOptions . ValueTo } , Increment: { batchOptions . Increment } ") ;
216
-
217
- // Check guidance
218
- var performGuidance = ShouldPerformGuidance ( options ) ;
219
-
220
- // Process prompts
221
- var promptEmbeddings = await CreatePromptEmbedsAsync ( promptOptions , performGuidance ) ;
222
-
223
- // Generate batch options
224
- var batchSchedulerOptions = BatchGenerator . GenerateBatch ( this , batchOptions , options ) ;
225
-
226
- // Create Diffuser
227
- var diffuser = CreateDiffuser ( promptOptions . DiffuserType , controlNet ) ;
228
-
229
- // Diffuse
230
- var batchIndex = 1 ;
231
- var batchSchedulerCallback = CreateBatchCallback ( progressCallback , batchSchedulerOptions . Count , ( ) => batchIndex ) ;
232
- foreach ( var batchSchedulerOption in batchSchedulerOptions )
189
+ await foreach ( var batchResult in RunBatchInternalAsync ( batchOptions , promptOptions , schedulerOptions , controlNet , progressCallback , cancellationToken ) )
233
190
{
234
- var tensorResult = default ( DenseTensor < float > ) ;
235
- if ( promptOptions . HasInputVideo )
236
- {
237
- await foreach ( var frameTensor in DiffuseVideoAsync ( diffuser , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) )
238
- {
239
- tensorResult = tensorResult . Concatenate ( frameTensor ) ;
240
- }
241
- }
242
- else
243
- {
244
- tensorResult = await DiffuseImageAsync ( diffuser , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ;
245
- }
246
- yield return new BatchResult ( batchSchedulerOption , tensorResult ) ;
247
- batchIndex ++ ;
191
+ var tensor = batchResult . Result . Count == 1
192
+ ? batchResult . Result . First ( ) // ImageTensor
193
+ : batchResult . Result . Join ( ) ; // VideoTensor
194
+ yield return new BatchResult ( batchResult . SchedulerOptions , tensor ) ;
248
195
}
249
-
250
- _logger ? . LogEnd ( $ "Batch Diffuser complete", diffuseBatchTime ) ;
251
196
}
252
197
253
198
@@ -262,22 +207,8 @@ public override async IAsyncEnumerable<BatchResult> RunBatchAsync(BatchOptions b
262
207
/// <returns></returns>
263
208
public override async Task < OnnxImage > GenerateImageAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
264
209
{
265
- var diffuseTime = _logger ? . LogBegin ( "Diffuser starting..." ) ;
266
- var options = GetSchedulerOptionsOrDefault ( schedulerOptions ) ;
267
- _logger ? . Log ( $ "Model: { Name } , Pipeline: { PipelineType } , Diffuser: { promptOptions . DiffuserType } , Scheduler: { options . SchedulerType } ") ;
268
-
269
- // Check guidance
270
- var performGuidance = ShouldPerformGuidance ( options ) ;
271
-
272
- // Process prompts
273
- var promptEmbeddings = await CreatePromptEmbedsAsync ( promptOptions , performGuidance ) ;
274
-
275
- // Create Diffuser
276
- var diffuser = CreateDiffuser ( promptOptions . DiffuserType , controlNet ) ;
277
-
278
- var imageResult = await DiffuseImageAsync ( diffuser , promptOptions , options , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ;
279
-
280
- return new OnnxImage ( imageResult ) ;
210
+ var tensors = await RunInternalAsync ( promptOptions , schedulerOptions , controlNet , progressCallback , cancellationToken ) ;
211
+ return new OnnxImage ( tensors . First ( ) ) ;
281
212
}
282
213
283
214
@@ -293,47 +224,58 @@ public override async Task<OnnxImage> GenerateImageAsync(PromptOptions promptOpt
293
224
/// <returns></returns>
294
225
public override async IAsyncEnumerable < BatchImageResult > GenerateImageBatchAsync ( BatchOptions batchOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
295
226
{
296
- var diffuseBatchTime = _logger ? . LogBegin ( "Batch Diffuser starting..." ) ;
297
- var options = GetSchedulerOptionsOrDefault ( schedulerOptions ) ;
298
- _logger ? . Log ( $ "Model: { Name } , Pipeline: { PipelineType } , Diffuser: { promptOptions . DiffuserType } , Scheduler: { options . SchedulerType } ") ;
299
- _logger ? . Log ( $ "BatchType: { batchOptions . BatchType } , ValueFrom: { batchOptions . ValueFrom } , ValueTo: { batchOptions . ValueTo } , Increment: { batchOptions . Increment } ") ;
300
-
301
- // Check guidance
302
- var performGuidance = ShouldPerformGuidance ( options ) ;
227
+ await foreach ( var batchResult in RunBatchInternalAsync ( batchOptions , promptOptions , schedulerOptions , controlNet , progressCallback , cancellationToken ) )
228
+ {
229
+ yield return new BatchImageResult ( batchResult . SchedulerOptions , new OnnxImage ( batchResult . Result . First ( ) ) ) ;
230
+ }
231
+ }
303
232
304
- // Process prompts
305
- var promptEmbeddings = await CreatePromptEmbedsAsync ( promptOptions , performGuidance ) ;
306
233
307
- // Generate batch options
308
- var batchSchedulerOptions = BatchGenerator . GenerateBatch ( this , batchOptions , options ) ;
234
+ /// <summary>
235
+ /// Runs the pipeline returning the result as an OnnxVideo.
236
+ /// </summary>
237
+ /// <param name="promptOptions">The prompt options.</param>
238
+ /// <param name="schedulerOptions">The scheduler options.</param>
239
+ /// <param name="controlNet">The control net.</param>
240
+ /// <param name="progressCallback">The progress callback.</param>
241
+ /// <param name="cancellationToken">The cancellation token.</param>
242
+ /// <returns></returns>
243
+ public override async Task < OnnxVideo > GenerateVideoAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
244
+ {
245
+ var tensors = await RunInternalAsync ( promptOptions , schedulerOptions , controlNet , progressCallback , cancellationToken ) ;
246
+ return new OnnxVideo ( promptOptions . InputVideo . Info , tensors ) ;
247
+ }
309
248
310
- // Create Diffuser
311
- var diffuser = CreateDiffuser ( promptOptions . DiffuserType , controlNet ) ;
312
249
313
- // Diffuse
314
- var batchIndex = 1 ;
315
- var batchSchedulerCallback = CreateBatchCallback ( progressCallback , batchSchedulerOptions . Count , ( ) => batchIndex ) ;
316
- foreach ( var batchSchedulerOption in batchSchedulerOptions )
250
+ /// <summary>
251
+ /// Runs the batch pipeline returning the result as an OnnxVideo.
252
+ /// </summary>
253
+ /// <param name="batchOptions">The batch options.</param>
254
+ /// <param name="promptOptions">The prompt options.</param>
255
+ /// <param name="schedulerOptions">The scheduler options.</param>
256
+ /// <param name="controlNet">The control net.</param>
257
+ /// <param name="progressCallback">The progress callback.</param>
258
+ /// <param name="cancellationToken">The cancellation token.</param>
259
+ /// <returns></returns>
260
+ public override async IAsyncEnumerable < BatchVideoResult > GenerateVideoBatchAsync ( BatchOptions batchOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
261
+ {
262
+ await foreach ( var batchResult in RunBatchInternalAsync ( batchOptions , promptOptions , schedulerOptions , controlNet , progressCallback , cancellationToken ) )
317
263
{
318
- var tensorResult = await DiffuseImageAsync ( diffuser , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ;
319
- yield return new BatchImageResult ( batchSchedulerOption , new OnnxImage ( tensorResult ) ) ;
320
- batchIndex ++ ;
264
+ yield return new BatchVideoResult ( batchResult . SchedulerOptions , new OnnxVideo ( promptOptions . InputVideo . Info , batchResult . Result ) ) ;
321
265
}
322
-
323
- _logger ? . LogEnd ( $ "Batch Diffuser complete", diffuseBatchTime ) ;
324
266
}
325
267
326
268
327
269
/// <summary>
328
- /// Runs the pipeline returning the result as an OnnxVideo.
270
+ /// Runs the pipeline
329
271
/// </summary>
330
272
/// <param name="promptOptions">The prompt options.</param>
331
273
/// <param name="schedulerOptions">The scheduler options.</param>
332
274
/// <param name="controlNet">The control net.</param>
333
275
/// <param name="progressCallback">The progress callback.</param>
334
276
/// <param name="cancellationToken">The cancellation token.</param>
335
277
/// <returns></returns>
336
- public override async Task < OnnxVideo > GenerateVideoAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
278
+ protected virtual async Task < List < DenseTensor < float > > > RunInternalAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
337
279
{
338
280
var diffuseTime = _logger ? . LogBegin ( "Diffuser starting..." ) ;
339
281
var options = GetSchedulerOptionsOrDefault ( schedulerOptions ) ;
@@ -348,17 +290,30 @@ public override async Task<OnnxVideo> GenerateVideoAsync(PromptOptions promptOpt
348
290
// Create Diffuser
349
291
var diffuser = CreateDiffuser ( promptOptions . DiffuserType , controlNet ) ;
350
292
351
- var frames = new List < OnnxImage > ( ) ;
352
- await foreach ( var frameTensor in DiffuseVideoAsync ( diffuser , promptOptions , options , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) )
293
+ // Diffuse
294
+ var tensorResult = new List < DenseTensor < float > > ( ) ;
295
+ if ( promptOptions . HasInputVideo )
353
296
{
354
- frames . Add ( new OnnxImage ( frameTensor ) ) ;
297
+ var frameIndex = 1 ;
298
+ var frameSchedulerCallback = CreateBatchCallback ( progressCallback , promptOptions . InputVideo . Frames . Count , ( ) => frameIndex ) ;
299
+ await foreach ( var frameTensor in DiffuseVideoAsync ( diffuser , promptOptions , options , promptEmbeddings , performGuidance , frameSchedulerCallback , cancellationToken ) )
300
+ {
301
+ frameIndex ++ ;
302
+ tensorResult . Add ( frameTensor ) ;
303
+ }
355
304
}
356
- return new OnnxVideo ( promptOptions . InputVideo . Info , frames ) ;
305
+ else
306
+ {
307
+ tensorResult . Add ( await DiffuseImageAsync ( diffuser , promptOptions , options , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ) ;
308
+ }
309
+
310
+ _logger ? . LogEnd ( $ "Diffuser complete", diffuseTime ) ;
311
+ return tensorResult ;
357
312
}
358
313
359
314
360
315
/// <summary>
361
- /// Runs the batch pipeline returning the result as an OnnxVideo .
316
+ /// Runs the pipeline batch .
362
317
/// </summary>
363
318
/// <param name="batchOptions">The batch options.</param>
364
319
/// <param name="promptOptions">The prompt options.</param>
@@ -367,7 +322,7 @@ public override async Task<OnnxVideo> GenerateVideoAsync(PromptOptions promptOpt
367
322
/// <param name="progressCallback">The progress callback.</param>
368
323
/// <param name="cancellationToken">The cancellation token.</param>
369
324
/// <returns></returns>
370
- public override async IAsyncEnumerable < BatchVideoResult > GenerateVideoBatchAsync ( BatchOptions batchOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
325
+ protected virtual async IAsyncEnumerable < BatchResultInternal > RunBatchInternalAsync ( BatchOptions batchOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
371
326
{
372
327
var diffuseBatchTime = _logger ? . LogBegin ( "Batch Diffuser starting..." ) ;
373
328
var options = GetSchedulerOptionsOrDefault ( schedulerOptions ) ;
@@ -387,19 +342,26 @@ public override async IAsyncEnumerable<BatchVideoResult> GenerateVideoBatchAsync
387
342
var diffuser = CreateDiffuser ( promptOptions . DiffuserType , controlNet ) ;
388
343
389
344
// Diffuse
390
- var batchIndex = 1 ;
345
+ var batchIndex = 1 ; // TODO: Video batch callback shoud be (BatchIndex + FrameIndex), not (BatchIndex + StepIndex)
391
346
var batchSchedulerCallback = CreateBatchCallback ( progressCallback , batchSchedulerOptions . Count , ( ) => batchIndex ) ;
392
347
foreach ( var batchSchedulerOption in batchSchedulerOptions )
393
348
{
394
- var frames = new List < OnnxImage > ( ) ;
395
- await foreach ( var frameTensor in DiffuseVideoAsync ( diffuser , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) )
349
+ var tensorResult = new List < DenseTensor < float > > ( ) ;
350
+ if ( promptOptions . HasInputVideo )
396
351
{
397
- frames . Add ( new OnnxImage ( frameTensor ) ) ;
352
+ await foreach ( var frameTensor in DiffuseVideoAsync ( diffuser , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , batchSchedulerCallback , cancellationToken ) )
353
+ {
354
+ tensorResult . Add ( frameTensor ) ;
355
+ }
356
+ }
357
+ else
358
+ {
359
+ tensorResult . Add ( await DiffuseImageAsync ( diffuser , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , batchSchedulerCallback , cancellationToken ) ) ;
398
360
}
399
- yield return new BatchVideoResult ( batchSchedulerOption , new OnnxVideo ( promptOptions . InputVideo . Info , frames ) ) ;
361
+
400
362
batchIndex ++ ;
363
+ yield return new BatchResultInternal ( batchSchedulerOption , tensorResult ) ;
401
364
}
402
-
403
365
_logger ? . LogEnd ( $ "Batch Diffuser complete", diffuseBatchTime ) ;
404
366
}
405
367
@@ -623,5 +585,4 @@ public static StableDiffusionPipeline CreatePipeline(string modelFolder, ModelTy
623
585
return CreatePipeline ( ModelFactory . CreateModelSet ( modelFolder , DiffuserPipelineType . StableDiffusion , modelType , deviceId , executionProvider , memoryMode ) , logger ) ;
624
586
}
625
587
}
626
-
627
588
}
0 commit comments