Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 41 additions & 24 deletions invokeai/app/invocations/latent.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,40 @@ class Config:
class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents"""
#fmt: off
type: Literal["latent_output"] = "latent_output"
latents: LatentsField = Field(default=None, description="The output latents")
type: Literal["latents_output"] = "latents_output"

# Inputs
latents: LatentsField = Field(default=None, description="The output latents")
width: int = Field(description="The width of the latents in pixels")
height: int = Field(description="The height of the latents in pixels")
#fmt: on


def build_latents_output(latents_name: str, latents: torch.Tensor):
return LatentsOutput(
latents=LatentsField(latents_name=latents_name),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)

class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output"""
#fmt: off
type: Literal["noise_output"] = "noise_output"
type: Literal["noise_output"] = "noise_output"

# Inputs
noise: LatentsField = Field(default=None, description="The output noise")
width: int = Field(description="The width of the noise in pixels")
height: int = Field(description="The height of the noise in pixels")
#fmt: on

def build_noise_output(latents_name: str, latents: torch.Tensor):
return NoiseOutput(
noise=LatentsField(latents_name=latents_name),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)


SAMPLER_NAME_VALUES = Literal[
tuple(list(SCHEDULER_MAP.keys()))
Expand Down Expand Up @@ -122,9 +145,7 @@ def invoke(self, context: InvocationContext) -> NoiseOutput:

name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, noise)
return NoiseOutput(
noise=LatentsField(latents_name=name)
)
return build_noise_output(latents_name=name, latents=noise)


# Text to image
Expand Down Expand Up @@ -240,9 +261,7 @@ def step_callback(state: PipelineIntermediateState):

name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
return LatentsOutput(
latents=LatentsField(latents_name=name)
)
return build_latents_output(latents_name=name, latents=result_latents)


class LatentsToLatentsInvocation(TextToLatentsInvocation):
Expand Down Expand Up @@ -301,9 +320,7 @@ def step_callback(state: PipelineIntermediateState):

name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
return LatentsOutput(
latents=LatentsField(latents_name=name)
)
return build_latents_output(latents_name=name, latents=result_latents)


# Latent to image
Expand Down Expand Up @@ -367,11 +384,11 @@ class ResizeLatentsInvocation(BaseInvocation):
type: Literal["lresize"] = "lresize"

# Inputs
latents: Optional[LatentsField] = Field(description="The latents to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
latents: Optional[LatentsField] = Field(description="The latents to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")

def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
Expand All @@ -388,7 +405,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:

name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
return LatentsOutput(latents=LatentsField(latents_name=name))
return build_latents_output(latents_name=name, latents=resized_latents)


class ScaleLatentsInvocation(BaseInvocation):
Expand All @@ -397,10 +414,10 @@ class ScaleLatentsInvocation(BaseInvocation):
type: Literal["lscale"] = "lscale"

# Inputs
latents: Optional[LatentsField] = Field(description="The latents to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
latents: Optional[LatentsField] = Field(description="The latents to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")

def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
Expand All @@ -418,7 +435,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:

name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
return LatentsOutput(latents=LatentsField(latents_name=name))
return build_latents_output(latents_name=name, latents=resized_latents)


class ImageToLatentsInvocation(BaseInvocation):
Expand Down Expand Up @@ -462,4 +479,4 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:

name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, latents)
return LatentsOutput(latents=LatentsField(latents_name=name))
return build_latents_output(latents_name=name, latents=latents)