284 changes: 108 additions & 176 deletions invokeai/app/invocations/generate.py
@@ -12,127 +12,68 @@
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods

-from ...backend.generator import Img2Img, Inpaint, InvokeAIGenerator, Txt2Img
+from ...backend.generator import Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState
from ..util.step_callback import stable_diffusion_step_callback
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
from .image import ImageOutput

+import re
+from ...backend.model_management.lora import ModelPatcher
+from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
+from .model import UNetField, VaeField
+from .compel import ConditioningField
+from contextlib import contextmanager, ExitStack, ContextDecorator

SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = (
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
)
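A note on the unchanged lines above: `Literal[tuple(...)]` builds the set of allowed choices at import time from whatever the backend actually reports, and pydantic then validates the field against it. A minimal standalone sketch of the same pattern; the `AVAILABLE` tuple is a made-up stand-in for `InvokeAIGenerator.schedulers()` / `infill_methods()`:

```python
from typing import Literal, get_args
from pydantic import BaseModel, Field

# Hypothetical runtime-discovered choices, standing in for
# InvokeAIGenerator.schedulers() / infill_methods().
AVAILABLE = ("euler", "ddim", "lms")
SCHEDULER_VALUES = Literal[tuple(AVAILABLE)]  # type: ignore[valid-type]

class Params(BaseModel):
    scheduler: SCHEDULER_VALUES = Field(default="euler")

print(get_args(SCHEDULER_VALUES))  # ('euler', 'ddim', 'lms')
Params(scheduler="ddim")           # ok
# Params(scheduler="plms")         # would raise a ValidationError
```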


-class SDImageInvocation(BaseModel):
-    """Helper class to provide all Stable Diffusion raster image invocations with additional config"""
-
-    # Schema customisation
-    class Config(InvocationConfig):
-        schema_extra = {
-            "ui": {
-                "tags": ["stable-diffusion", "image"],
-                "type_hints": {
-                    "model": "model",
-                },
-            },
-        }
-
-
-# Text to image
-class TextToImageInvocation(BaseInvocation, SDImageInvocation):
-    """Generates an image using text2img."""
-
-    type: Literal["txt2img"] = "txt2img"
-
-    # Inputs
-    # TODO: consider making prompt optional to enable providing prompt through a link
-    # fmt: off
-    prompt: Optional[str] = Field(description="The prompt to generate an image from")
-    seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
-    steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
-    width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
-    height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
-    cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
-    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
-    model: str = Field(default="", description="The model to use (currently ignored)")
-    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
-    control_model: Optional[str] = Field(default=None, description="The control model to use")
-    control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
-    # fmt: on
+from .latent import get_scheduler

-    # TODO: pass this an emitter method or something? or a session for dispatching?
-    def dispatch_progress(
-        self,
-        context: InvocationContext,
-        source_node_id: str,
-        intermediate_state: PipelineIntermediateState,
-    ) -> None:
-        stable_diffusion_step_callback(
-            context=context,
-            intermediate_state=intermediate_state,
-            node=self.dict(),
-            source_node_id=source_node_id,
-        )
-
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        # Handle invalid model parameter
-        model = context.services.model_manager.get_model(self.model, node=self, context=context)
-
-        # loading controlnet image (currently requires pre-processed image)
-        control_image = (
-            None if self.control_image is None
-            else context.services.images.get_pil_image(self.control_image.image_name)
-        )
-        # loading controlnet model
-        if (self.control_model is None or self.control_model == ''):
-            control_model = None
-        else:
-            # FIXME: change this to dropdown menu?
-            # FIXME: generalize so don't have to hardcode torch_dtype and device
-            control_model = ControlNetModel.from_pretrained(self.control_model,
-                                                            torch_dtype=torch.float16).to("cuda")
-
-        # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(
-            context.graph_execution_state_id
-        )
-        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
-
-        txt2img = Txt2Img(model, control_model=control_model)
-        outputs = txt2img.generate(
-            prompt=self.prompt,
-            step_callback=partial(self.dispatch_progress, context, source_node_id),
-            control_image=control_image,
-            **self.dict(
-                exclude={"prompt", "control_image"}
-            ),  # Shorthand for passing all of the parameters above manually
-        )
-
-        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
-        # each time it is called. We only need the first one.
-        generate_output = next(outputs)
-
-        image_dto = context.services.images.create(
-            image=generate_output.image,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            session_id=context.graph_execution_state_id,
-            node_id=self.id,
-            is_intermediate=self.is_intermediate,
-        )
-
-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
+class OldModelContext(ContextDecorator):
+    model: StableDiffusionGeneratorPipeline
+
+    def __init__(self, model):
+        self.model = model
+
+    def __enter__(self):
+        return self.model
+
+    def __exit__(self, *exc):
+        return False
+
+
+class OldModelInfo:
+    name: str
+    hash: str
+    context: OldModelContext
+
+    def __init__(self, name: str, hash: str, model: StableDiffusionGeneratorPipeline):
+        self.name = name
+        self.hash = hash
+        self.context = OldModelContext(
+            model=model,
+        )
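The two added classes above are a small compatibility shim: `OldModelInfo` mimics the new model manager's record shape (name, hash, and a `context` usable as a context manager), while `OldModelContext` does no loading or unloading and simply hands back the pipeline it wraps. A short usage sketch, assuming callers only ever do `with info.context as model:`; the `FakePipeline` stand-in is illustrative, not part of this PR:

```python
# FakePipeline is an illustrative stand-in for StableDiffusionGeneratorPipeline.
class FakePipeline:
    pass

pipeline = FakePipeline()
info = OldModelInfo(name="stable-diffusion-1.5", hash="<NO-HASH>", model=pipeline)

# __enter__ returns the wrapped model unchanged; __exit__ returns False,
# so exceptions raised inside the block still propagate.
with info.context as model:
    assert model is pipeline
```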

-class ImageToImageInvocation(TextToImageInvocation):
-    """Generates an image using img2img."""
-
-    type: Literal["img2img"] = "img2img"
+class InpaintInvocation(BaseInvocation):
+    """Generates an image using inpaint."""
+
+    type: Literal["inpaint"] = "inpaint"
+
+    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
+    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
+    seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
+    steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
+    width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
+    height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
+    cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
+    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
+    unet: UNetField = Field(default=None, description="UNet model")
+    vae: VaeField = Field(default=None, description="Vae model")

    # Inputs
    image: Union[ImageField, None] = Field(description="The input image")
@@ -144,72 +85,6 @@ class ImageToImageInvocation(TextToImageInvocation):
description="Whether or not the result should be fit to the aspect ratio of the input image",
)

-    def dispatch_progress(
-        self,
-        context: InvocationContext,
-        source_node_id: str,
-        intermediate_state: PipelineIntermediateState,
-    ) -> None:
-        stable_diffusion_step_callback(
-            context=context,
-            intermediate_state=intermediate_state,
-            node=self.dict(),
-            source_node_id=source_node_id,
-        )
-
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = (
-            None
-            if self.image is None
-            else context.services.images.get_pil_image(self.image.image_name)
-        )
-
-        if self.fit:
-            image = image.resize((self.width, self.height))
-
-        # Handle invalid model parameter
-        model = context.services.model_manager.get_model(self.model, node=self, context=context)
-
-        # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(
-            context.graph_execution_state_id
-        )
-        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
-
-        outputs = Img2Img(model).generate(
-            prompt=self.prompt,
-            init_image=image,
-            step_callback=partial(self.dispatch_progress, context, source_node_id),
-            **self.dict(
-                exclude={"prompt", "image", "mask"}
-            ),  # Shorthand for passing all of the parameters above manually
-        )
-
-        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
-        # each time it is called. We only need the first one.
-        generator_output = next(outputs)
-
-        image_dto = context.services.images.create(
-            image=generator_output.image,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            session_id=context.graph_execution_state_id,
-            node_id=self.id,
-            is_intermediate=self.is_intermediate,
-        )
-
-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )


-class InpaintInvocation(ImageToImageInvocation):
-    """Generates an image using inpaint."""
-
-    type: Literal["inpaint"] = "inpaint"

    # Inputs
    mask: Union[ImageField, None] = Field(description="The mask")
    seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
@@ -252,6 +127,14 @@ class InpaintInvocation(ImageToImageInvocation):
description="The amount by which to replace masked areas with latent noise",
)

+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["stable-diffusion", "image"],
+            },
+        }

    def dispatch_progress(
        self,
        context: InvocationContext,
@@ -265,6 +148,49 @@ def dispatch_progress(
            source_node_id=source_node_id,
        )

+    def get_conditioning(self, context):
+        c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
+        uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
+
+        return (uc, c, extra_conditioning_info)
+
+    @contextmanager
+    def load_model_old_way(self, context, scheduler):
+        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
+        vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
+
+        #unet = unet_info.context.model
+        #vae = vae_info.context.model
+
+        with ExitStack() as stack:
+            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
+
+            with vae_info as vae,\
+                unet_info as unet,\
+                ModelPatcher.apply_lora_unet(unet, loras):
+
+                device = context.services.model_manager.mgr.cache.execution_device
+                dtype = context.services.model_manager.mgr.cache.precision
+
+                pipeline = StableDiffusionGeneratorPipeline(
+                    vae=vae,
+                    text_encoder=None,
+                    tokenizer=None,
+                    unet=unet,
+                    scheduler=scheduler,
+                    safety_checker=None,
+                    feature_extractor=None,
+                    requires_safety_checker=False,
+                    precision="float16" if dtype == torch.float16 else "float32",
+                    execution_device=device,
+                )
+
+                yield OldModelInfo(
+                    name=self.unet.unet.model_name,
+                    hash="<NO-HASH>",
+                    model=pipeline,
+                )
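`load_model_old_way` stacks several context managers: an `ExitStack` holds an arbitrary number of LoRA models open at once, the UNet/VAE model-manager contexts load the weights, and `ModelPatcher.apply_lora_unet` patches the UNet and reverts it on exit. A stripped-down, runnable sketch of the same composition pattern; `fake_resource` and `load_everything` are placeholders, not InvokeAI APIs:

```python
from contextlib import ExitStack, contextmanager

@contextmanager
def fake_resource(name):
    print(f"open {name}")
    try:
        yield name
    finally:
        print(f"close {name}")

@contextmanager
def load_everything(lora_names):
    with ExitStack() as stack:
        # Keep an arbitrary number of LoRAs open at once; ExitStack
        # closes them in reverse order when the outer block exits.
        loras = [stack.enter_context(fake_resource(n)) for n in lora_names]
        with fake_resource("unet") as unet, fake_resource("vae") as vae:
            yield (unet, vae, loras)

with load_everything(["lora-a", "lora-b"]) as (unet, vae, loras):
    print(unet, vae, loras)
```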

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = (
            None
@@ -277,24 +203,30 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
            else context.services.images.get_pil_image(self.mask.image_name)
        )

-        # Handle invalid model parameter
-        model = context.services.model_manager.get_model(self.model, node=self, context=context)
-
        # Get the source node id (we are invoking the prepared node)
        graph_execution_state = context.services.graph_execution_manager.get(
            context.graph_execution_state_id
        )
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]

-        outputs = Inpaint(model).generate(
-            prompt=self.prompt,
-            init_image=image,
-            mask_image=mask,
-            step_callback=partial(self.dispatch_progress, context, source_node_id),
-            **self.dict(
-                exclude={"prompt", "image", "mask"}
-            ),  # Shorthand for passing all of the parameters above manually
-        )
+        conditioning = self.get_conditioning(context)
+        scheduler = get_scheduler(
+            context=context,
+            scheduler_info=self.unet.scheduler,
+            scheduler_name=self.scheduler,
+        )
+
+        with self.load_model_old_way(context, scheduler) as model:
+            outputs = Inpaint(model).generate(
+                conditioning=conditioning,
+                scheduler=scheduler,
+                init_image=image,
+                mask_image=mask,
+                step_callback=partial(self.dispatch_progress, context, source_node_id),
+                **self.dict(
+                    exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"}
+                ),  # Shorthand for passing all of the parameters above manually
+            )

        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
        # each time it is called. We only need the first one.
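A note on the `**self.dict(exclude={...})` shorthand used in both the old and new `invoke`: pydantic's `.dict(exclude=...)` forwards every remaining field as a keyword argument, so the exclude set must name every field that is also passed explicitly, which is why it grew to cover `positive_conditioning`, `negative_conditioning`, and `scheduler`. Otherwise the call would raise `TypeError: got multiple values for keyword argument`. A minimal pydantic v1 illustration; `Node` and `generate` are toy stand-ins, not this codebase:

```python
from pydantic import BaseModel

class Node(BaseModel):
    prompt: str = "hi"
    steps: int = 30
    width: int = 512

def generate(*, prompt, steps, width):
    print(prompt, steps, width)

node = Node()
# Pass `prompt` explicitly, splat the rest; `exclude` prevents a
# duplicate-keyword TypeError for the explicitly passed field.
generate(prompt=node.prompt, **node.dict(exclude={"prompt"}))
```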
1 change: 0 additions & 1 deletion invokeai/backend/__init__.py
@@ -5,7 +5,6 @@
    InvokeAIGeneratorBasicParams,
    InvokeAIGenerator,
    InvokeAIGeneratorOutput,
-    Txt2Img,
    Img2Img,
    Inpaint
)
1 change: 0 additions & 1 deletion invokeai/backend/generator/__init__.py
@@ -5,7 +5,6 @@
    InvokeAIGenerator,
    InvokeAIGeneratorBasicParams,
    InvokeAIGeneratorOutput,
-    Txt2Img,
    Img2Img,
    Inpaint,
    Generator,
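With `Txt2Img` dropped from both re-export lists, downstream code that imports it from `invokeai.backend` or `invokeai.backend.generator` now fails at import time rather than at call time. A defensive pattern a consumer might use during migration; a sketch only, and the suggested fallback is an assumption rather than guidance from this PR:

```python
# Hypothetical downstream import that this PR breaks:
try:
    from invokeai.backend.generator import Txt2Img  # removed by this PR
except ImportError:
    # Fall back / migrate to the node-based text-to-image pipeline.
    Txt2Img = None
```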