2 changes: 1 addition & 1 deletion invokeai/app/invocations/latent.py
@@ -764,7 +764,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!

latents = 0.18215 * latents
latents = vae.config.scaling_factor * latents
latents = latents.to(dtype=orig_dtype)

name = f"{context.graph_execution_state_id}__{self.id}"
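The hardcoded 0.18215 removed above is the scaling factor of the SD 1.x VAE; reading vae.config.scaling_factor instead lets the same node work with VAEs trained against a different value, such as SDXL's. A minimal sketch of the difference, assuming diffusers' AutoencoderKL (the repo IDs and the 0.13025 value are published defaults, not something this diff specifies):

# Hedged sketch: the factor now comes from the VAE config instead of a constant.
# Repo IDs below are illustrative.
from diffusers import AutoencoderKL

sd15_vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
sdxl_vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae")
print(sd15_vae.config.scaling_factor)  # 0.18215 -- the previously hardcoded value
print(sdxl_vae.config.scaling_factor)  # 0.13025 -- would have been wrong before this change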
47 changes: 47 additions & 0 deletions invokeai/app/invocations/sdxl.py
@@ -6,6 +6,7 @@
from pydantic import Field, validator

from ...backend.model_management import ModelType, SubModelType
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)

@@ -243,10 +244,31 @@ class Config(InvocationConfig):
},
}

def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
sample,
step,
total_steps,
) -> None:
stable_diffusion_xl_step_callback(
context=context,
node=self.dict(),
source_node_id=source_node_id,
sample=sample,
step=step,
total_steps=total_steps,
)

# based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
latents = context.services.latents.get(self.noise.latents_name)

positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
@@ -341,6 +363,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
else:
@@ -409,6 +432,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)

@@ -473,10 +497,31 @@ class Config(InvocationConfig):
},
}

def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
sample,
step,
total_steps,
) -> None:
stable_diffusion_xl_step_callback(
context=context,
node=self.dict(),
source_node_id=source_node_id,
sample=sample,
step=step,
total_steps=total_steps,
)

# based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
latents = context.services.latents.get(self.latents.latents_name)

positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
@@ -579,6 +624,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
else:
@@ -647,6 +693,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)

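Both new nodes emit progress on the cadence diffusers pipelines use for callbacks: always on the final step, otherwise on every scheduler-order-aligned step after warmup. A self-contained sketch of which steps fire under that condition (the warmup formula follows the usual diffusers convention; the toy timesteps list is a stand-in for a real schedule):

# Hedged sketch: steps that report progress for 30 inference steps with a
# hypothetical second-order scheduler.
num_inference_steps = 30
scheduler_order = 2
timesteps = list(range(num_inference_steps * scheduler_order))  # toy stand-in
num_warmup_steps = len(timesteps) - num_inference_steps * scheduler_order  # 0 here

reported = [
    i for i in range(len(timesteps))
    if i == len(timesteps) - 1
    or ((i + 1) > num_warmup_steps and (i + 1) % scheduler_order == 0)
]
print(reported[:5], "...", reported[-1])  # [1, 3, 5, 7, 9] ... 59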
93 changes: 92 additions & 1 deletion invokeai/app/util/step_callback.py
@@ -1,9 +1,30 @@
import torch
from PIL import Image
from invokeai.app.models.exceptions import CanceledException
from invokeai.app.models.image import ProgressImage
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator
from ...backend.stable_diffusion import PipelineIntermediateState
from invokeai.app.services.config import InvokeAIAppConfig


def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix = None):
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors

if smooth_matrix is not None:
latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2)
latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1,1,3,3)), padding=1)
latent_image = latent_image.permute(1, 2, 3, 0).squeeze(0)

latents_ubyte = (
((latent_image + 1) / 2)
.clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
.byte()
).cpu()

return Image.fromarray(latents_ubyte.numpy())
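This helper projects each latent pixel's four channels onto RGB with a single [4, 3] matmul (plus an optional 3x3 smoothing pass), so a per-step preview costs far less than a VAE decode. A usage sketch, reusing the v1.5 factors defined later in this file with a random stand-in latent:

# Usage sketch for the helper above: a random [1, 4, H/8, W/8] latent stands in
# for a real in-progress sample.
import torch

latents = torch.randn(1, 4, 64, 64)
v1_5_factors = torch.tensor([
    [0.3444, 0.1385, 0.0670],
    [0.1247, 0.4027, 0.1494],
    [-0.3192, 0.2513, 0.2103],
    [-0.1307, -0.1874, -0.7445],
])
preview = sample_to_lowres_estimated_image(latents, v1_5_factors)
print(preview.size)  # (64, 64); callers multiply by 8 to report image-space size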


def stable_diffusion_step_callback(
@@ -37,7 +58,24 @@ def stable_diffusion_step_callback(
# step = intermediate_state.step

# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)

# originally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7

# these updated numbers for v1.5 are from @torridgristle
v1_5_latent_rgb_factors = torch.tensor(
[
# R G B
[0.3444, 0.1385, 0.0670], # L1
[0.1247, 0.4027, 0.1494], # L2
[-0.3192, 0.2513, 0.2103], # L3
[-0.1307, -0.1874, -0.7445], # L4
],
dtype=sample.dtype,
device=sample.device,
)

image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)

(width, height) = image.size
width *= 8
@@ -53,3 +91,56 @@
step=intermediate_state.step,
total_steps=node["steps"],
)

def stable_diffusion_xl_step_callback(
context: InvocationContext,
node: dict,
source_node_id: str,
sample,
step,
total_steps,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException

sdxl_latent_rgb_factors = torch.tensor(
[
# R G B
[ 0.3816, 0.4930, 0.5320],
[-0.3753, 0.1631, 0.1739],
[ 0.1770, 0.3588, -0.2048],
[-0.4350, -0.2644, -0.4289],
],
dtype=sample.dtype,
device=sample.device,
)

sdxl_smooth_matrix = torch.tensor(
[
#[ 0.0478, 0.1285, 0.0478],
#[ 0.1285, 0.2948, 0.1285],
#[ 0.0478, 0.1285, 0.0478],
[0.0358, 0.0964, 0.0358],
[0.0964, 0.4711, 0.0964],
[0.0358, 0.0964, 0.0358],
],
dtype=sample.dtype,
device=sample.device,
)

image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)

(width, height) = image.size
width *= 8
height *= 8

dataURL = image_to_dataURL(image, image_format="JPEG")

context.services.events.emit_generator_progress(
graph_execution_state_id=context.graph_execution_state_id,
node=node,
source_node_id=source_node_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=step,
total_steps=total_steps,
)
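
The SDXL path additionally smooths the projected image with the 3x3 kernel above. Its entries sum to roughly 1, so it acts as a brightness-preserving blur; moving the RGB channels into the batch dimension lets a single [1, 1, 3, 3] conv2d weight filter each channel independently. A quick check of that normalization:

# Sketch: the smoothing kernel above is (approximately) normalized.
import torch

smooth = torch.tensor([
    [0.0358, 0.0964, 0.0358],
    [0.0964, 0.4711, 0.0964],
    [0.0358, 0.0964, 0.0358],
])
print(smooth.sum())  # tensor(0.9999), ~1.0, so average brightness is preserved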