Skip to content

Commit ddf7ddc

Browse files
authored
Add sdxl generation preview (#3862)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [x] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Have you discussed this change with the InvokeAI team?

- [x] Yes
- [ ] No, because:

## Description

Add a progress preview for the SDXL generation nodes.
2 parents 20757d1 + 4a0774b commit ddf7ddc

File tree

3 files changed

+140
-2
lines changed

3 files changed

+140
-2
lines changed

invokeai/app/invocations/latent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
764764
dtype=vae.dtype
765765
) # FIXME: uses torch.randn. make reproducible!
766766

767-
latents = 0.18215 * latents
767+
latents = vae.config.scaling_factor * latents
768768
latents = latents.to(dtype=orig_dtype)
769769

770770
name = f"{context.graph_execution_state_id}__{self.id}"

invokeai/app/invocations/sdxl.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pydantic import Field, validator
77

88
from ...backend.model_management import ModelType, SubModelType
9+
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
910
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
1011
InvocationConfig, InvocationContext)
1112

@@ -243,10 +244,31 @@ class Config(InvocationConfig):
243244
},
244245
}
245246

247+
def dispatch_progress(
    self,
    context: InvocationContext,
    source_node_id: str,
    sample,
    step,
    total_steps,
) -> None:
    """Report one step of progress through the SDXL step callback.

    The node is serialised with ``self.dict()`` and passed, together with
    every other argument unchanged, to ``stable_diffusion_xl_step_callback``.
    """
    progress_args = {
        "context": context,
        "node": self.dict(),
        "source_node_id": source_node_id,
        "sample": sample,
        "step": step,
        "total_steps": total_steps,
    }
    stable_diffusion_xl_step_callback(**progress_args)
263+
246264
# based on
247265
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
248266
@torch.no_grad()
249267
def invoke(self, context: InvocationContext) -> LatentsOutput:
268+
graph_execution_state = context.services.graph_execution_manager.get(
269+
context.graph_execution_state_id
270+
)
271+
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
250272
latents = context.services.latents.get(self.noise.latents_name)
251273

252274
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
@@ -341,6 +363,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
341363
# call the callback, if provided
342364
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
343365
progress_bar.update()
366+
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
344367
#if callback is not None and i % callback_steps == 0:
345368
# callback(i, t, latents)
346369
else:
@@ -409,6 +432,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
409432
# call the callback, if provided
410433
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
411434
progress_bar.update()
435+
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
412436
#if callback is not None and i % callback_steps == 0:
413437
# callback(i, t, latents)
414438

@@ -473,10 +497,31 @@ class Config(InvocationConfig):
473497
},
474498
}
475499

500+
def dispatch_progress(
    self,
    context: InvocationContext,
    source_node_id: str,
    sample,
    step,
    total_steps,
) -> None:
    """Report one step of progress through the SDXL step callback.

    The node is serialised with ``self.dict()`` and passed, together with
    every other argument unchanged, to ``stable_diffusion_xl_step_callback``.
    """
    progress_args = {
        "context": context,
        "node": self.dict(),
        "source_node_id": source_node_id,
        "sample": sample,
        "step": step,
        "total_steps": total_steps,
    }
    stable_diffusion_xl_step_callback(**progress_args)
516+
476517
# based on
477518
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
478519
@torch.no_grad()
479520
def invoke(self, context: InvocationContext) -> LatentsOutput:
521+
graph_execution_state = context.services.graph_execution_manager.get(
522+
context.graph_execution_state_id
523+
)
524+
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
480525
latents = context.services.latents.get(self.latents.latents_name)
481526

482527
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
@@ -579,6 +624,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
579624
# call the callback, if provided
580625
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
581626
progress_bar.update()
627+
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
582628
#if callback is not None and i % callback_steps == 0:
583629
# callback(i, t, latents)
584630
else:
@@ -647,6 +693,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
647693
# call the callback, if provided
648694
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
649695
progress_bar.update()
696+
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
650697
#if callback is not None and i % callback_steps == 0:
651698
# callback(i, t, latents)
652699

invokeai/app/util/step_callback.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,30 @@
1+
import torch
2+
from PIL import Image
13
from invokeai.app.models.exceptions import CanceledException
24
from invokeai.app.models.image import ProgressImage
35
from ..invocations.baseinvocation import InvocationContext
46
from ...backend.util.util import image_to_dataURL
57
from ...backend.generator.base import Generator
68
from ...backend.stable_diffusion import PipelineIntermediateState
9+
from invokeai.app.services.config import InvokeAIAppConfig
10+
11+
12+
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
    """Build a cheap preview image from the first entry of ``samples``.

    ``samples[0]`` has its leading axis moved to the end and is then
    matrix-multiplied by ``latent_rgb_factors``.
    NOTE(review): this assumes ``samples[0]`` is laid out as
    (channels, height, width) and that the factors have one row per latent
    channel and three columns -- confirm against the callers.

    When ``smooth_matrix`` is given, it is reshaped to a single 3x3 kernel and
    convolved (padding 1, so the spatial size is kept) over each output plane
    separately.

    The values are finally mapped from -1..1 to 0..1, clamped, scaled to
    0..255, converted to bytes on the CPU and wrapped in a PIL image.
    """
    rgb = samples[0].permute(1, 2, 0) @ latent_rgb_factors

    if smooth_matrix is not None:
        kernel = smooth_matrix.reshape((1, 1, 3, 3))
        # Move the last axis to the front so conv2d sees each plane as its
        # own single-channel batch item, then restore the original layout.
        planes = rgb.unsqueeze(0).permute(3, 0, 1, 2)
        planes = torch.nn.functional.conv2d(planes, kernel, padding=1)
        rgb = planes.permute(1, 2, 3, 0).squeeze(0)

    unit_range = (rgb + 1) / 2  # change scale from -1..1 to 0..1
    pixels = unit_range.clamp(0, 1).mul(0xFF).byte().cpu()  # to 0..255

    return Image.fromarray(pixels.numpy())
728

829

930
def stable_diffusion_step_callback(
@@ -37,7 +58,24 @@ def stable_diffusion_step_callback(
3758
# step = intermediate_state.step
3859

3960
# TODO: only output a preview image when requested
40-
image = Generator.sample_to_lowres_estimated_image(sample)
61+
62+
# originally adapted from code by @erucipe and @keturn here:
63+
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
64+
65+
# these updated numbers for v1.5 are from @torridgristle
66+
v1_5_latent_rgb_factors = torch.tensor(
67+
[
68+
# R G B
69+
[0.3444, 0.1385, 0.0670], # L1
70+
[0.1247, 0.4027, 0.1494], # L2
71+
[-0.3192, 0.2513, 0.2103], # L3
72+
[-0.1307, -0.1874, -0.7445], # L4
73+
],
74+
dtype=sample.dtype,
75+
device=sample.device,
76+
)
77+
78+
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
4179

4280
(width, height) = image.size
4381
width *= 8
@@ -53,3 +91,56 @@ def stable_diffusion_step_callback(
5391
step=intermediate_state.step,
5492
total_steps=node["steps"],
5593
)
94+
95+
def stable_diffusion_xl_step_callback(
    context: InvocationContext,
    node: dict,
    source_node_id: str,
    sample,
    step,
    total_steps,
):
    """Emit a generator-progress event carrying a low-res preview of ``sample``.

    The preview is produced by ``sample_to_lowres_estimated_image`` using the
    SDXL latent-to-RGB factors and a 3x3 smoothing kernel, encoded as a JPEG
    data URL, and sent through ``context.services.events``.

    Raises:
        CanceledException: if the queue reports this graph execution as
            canceled.
    """
    if context.services.queue.is_canceled(context.graph_execution_state_id):
        raise CanceledException

    # Build both constant tensors with the sample's dtype and device.
    tensor_opts = {"dtype": sample.dtype, "device": sample.device}

    # One row per latent channel; columns are R, G, B.
    rgb_factors = torch.tensor(
        [
            [0.3816, 0.4930, 0.5320],
            [-0.3753, 0.1631, 0.1739],
            [0.1770, 0.3588, -0.2048],
            [-0.4350, -0.2644, -0.4289],
        ],
        **tensor_opts,
    )

    # 3x3 kernel used to smooth the preview.
    blur_kernel = torch.tensor(
        [
            [0.0358, 0.0964, 0.0358],
            [0.0964, 0.4711, 0.0964],
            [0.0358, 0.0964, 0.0358],
        ],
        **tensor_opts,
    )

    preview = sample_to_lowres_estimated_image(sample, rgb_factors, blur_kernel)

    # The reported size is eight times the preview's size in each dimension.
    preview_width, preview_height = preview.size
    progress_image = ProgressImage(
        width=preview_width * 8,
        height=preview_height * 8,
        dataURL=image_to_dataURL(preview, image_format="JPEG"),
    )

    context.services.events.emit_generator_progress(
        graph_execution_state_id=context.graph_execution_state_id,
        node=node,
        source_node_id=source_node_id,
        progress_image=progress_image,
        step=step,
        total_steps=total_steps,
    )

0 commit comments

Comments
 (0)