Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 28267ce

Browse files
committed
Make timeids compatible with more model types
1 parent 9b4d450 commit 28267ce

File tree

3 files changed

+13
-25
lines changed

3 files changed

+13
-25
lines changed

OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/InpaintLegacyDiffuser.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
7272
var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);
7373

7474
// Get Time ids
75-
var addTimeIds = GetAddTimeIds(modelOptions, schedulerOptions, performGuidance);
75+
var addTimeIds = GetAddTimeIds(modelOptions, schedulerOptions);
7676

7777
// Loop though the timesteps
7878
var step = 0;
@@ -86,6 +86,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
8686
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
8787
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
8888
var timestepTensor = CreateTimestepTensor(timestep);
89+
var timeids = performGuidance ? addTimeIds.Repeat(2) : addTimeIds;
8990

9091
var outputChannels = performGuidance ? 2 : 1;
9192
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
@@ -95,7 +96,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
9596
inferenceParameters.AddInputTensor(timestepTensor);
9697
inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
9798
inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
98-
inferenceParameters.AddInputTensor(addTimeIds);
99+
inferenceParameters.AddInputTensor(timeids);
99100
inferenceParameters.AddOutputBuffer(outputDimension);
100101

101102
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);

OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/InpaintLegacyDiffuser.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
7272
var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);
7373

7474
// Get Time ids
75-
var addTimeIds = GetAddTimeIds(modelOptions, schedulerOptions, performGuidance);
75+
var addTimeIds = GetAddTimeIds(modelOptions, schedulerOptions);
7676

7777
// Loop though the timesteps
7878
var step = 0;
@@ -86,6 +86,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
8686
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
8787
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
8888
var timestepTensor = CreateTimestepTensor(timestep);
89+
var timeids = performGuidance ? addTimeIds.Repeat(2) : addTimeIds;
8990

9091
var outputChannels = performGuidance ? 2 : 1;
9192
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
@@ -95,7 +96,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
9596
inferenceParameters.AddInputTensor(timestepTensor);
9697
inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
9798
inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
98-
inferenceParameters.AddInputTensor(addTimeIds);
99+
inferenceParameters.AddInputTensor(timeids);
99100
inferenceParameters.AddOutputBuffer(outputDimension);
100101

101102
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);

OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/StableDiffusionXLDiffuser.cs

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
6060
var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);
6161

6262
// Get Time ids
63-
var addTimeIds = GetAddTimeIds(modelOptions, schedulerOptions, performGuidance);
63+
var addTimeIds = GetAddTimeIds(modelOptions, schedulerOptions);
6464

6565
// Loop though the timesteps
6666
var step = 0;
@@ -74,6 +74,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
7474
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
7575
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
7676
var timestepTensor = CreateTimestepTensor(timestep);
77+
var timeids = performGuidance ? addTimeIds.Repeat(2) : addTimeIds;
7778

7879
var outputChannels = performGuidance ? 2 : 1;
7980
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
@@ -83,7 +84,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
8384
inferenceParameters.AddInputTensor(timestepTensor);
8485
inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
8586
inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
86-
inferenceParameters.AddInputTensor(addTimeIds);
87+
inferenceParameters.AddInputTensor(timeids);
8788
inferenceParameters.AddOutputBuffer(outputDimension);
8889

8990
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
@@ -115,26 +116,11 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
115116
/// </summary>
116117
/// <param name="schedulerOptions">The scheduler options.</param>
117118
/// <returns></returns>
118-
protected DenseTensor<float> GetAddTimeIds(StableDiffusionModelSet model, SchedulerOptions schedulerOptions, bool performGuidance)
119+
protected DenseTensor<float> GetAddTimeIds(StableDiffusionModelSet model, SchedulerOptions schedulerOptions)
119120
{
120-
float[] result;
121-
if (model.ModelType == ModelType.Refiner)
122-
{
123-
//original_size + crops_coords_top_left + aesthetic_score
124-
//original_size + crops_coords_top_left + negative_aesthetic_score
125-
result = !performGuidance
126-
? new float[] { schedulerOptions.Height, schedulerOptions.Width, 0, 0, schedulerOptions.AestheticScore }
127-
: new float[] { schedulerOptions.Height, schedulerOptions.Width, 0, 0, schedulerOptions.AestheticNegativeScore, schedulerOptions.Height, schedulerOptions.Width, 0, 0, schedulerOptions.AestheticScore };
128-
}
129-
else
130-
{
131-
//original_size + crops_coords_top_left + target_size
132-
//original_size + crops_coords_top_left + negative_target_size
133-
result = !performGuidance
134-
? new float[] { schedulerOptions.Height, schedulerOptions.Width, 0, 0, schedulerOptions.Height, schedulerOptions.Width }
135-
: new float[] { schedulerOptions.Height, schedulerOptions.Width, 0, 0, schedulerOptions.Height, schedulerOptions.Width, schedulerOptions.Height, schedulerOptions.Width, 0, 0, schedulerOptions.Height, schedulerOptions.Width };
136-
}
137-
121+
float[] result = model.ModelType == ModelType.Refiner
122+
? new float[] { schedulerOptions.Height, schedulerOptions.Width, 0, 0, schedulerOptions.AestheticScore }
123+
: new float[] { schedulerOptions.Height, schedulerOptions.Width, 0, 0, schedulerOptions.Height, schedulerOptions.Width };
138124
return TensorHelper.CreateTensor(result, new[] { 1, result.Length });
139125
}
140126

0 commit comments

Comments
 (0)