
Commit 1a7fe17

Fix inpaint node to new manager (#3550)
The inpaint node is still used by the canvas, so it has been ported to the new model manager API; the other old generation code (the txt2img and img2img invocations) has been deleted.
2 parents: bb2df88 + 4f56930
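In short (a hedged reading of the diff below, with hypothetical stand-in names): the deleted invocations fetched one monolithic model by name via get_model(self.model, node=self, context=context), while the reworked InpaintInvocation takes UNetField/VaeField inputs and resolves each submodel through the new manager, entering the returned record as a context manager so weights are loaded only while generation runs. A minimal runnable toy sketch of that pattern, not InvokeAI's actual classes:

    # Stand-in for the record returned by the new model manager (hypothetical).
    class ModelInfo:
        def __init__(self, name: str):
            self.name = name

        def __enter__(self):
            print(f"loading {self.name} onto the execution device")
            return self.name  # the real manager returns the loaded submodel

        def __exit__(self, *exc):
            print(f"offloading {self.name}")
            return False


    def get_model(model_name: str, submodel: str) -> ModelInfo:
        # Stand-in for context.services.model_manager.get_model(**field.dict()).
        return ModelInfo(f"{model_name}/{submodel}")


    # New-manager shape used by InpaintInvocation: resolve UNet and VAE
    # separately, then hold both open only for the duration of generation.
    unet_info = get_model("stable-diffusion-1.5", "unet")
    vae_info = get_model("stable-diffusion-1.5", "vae")
    with unet_info as unet, vae_info as vae:
        print("assemble pipeline from", unet, "and", vae)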


72 files changed: +2143 -3462 lines


invokeai/app/invocations/generate.py

Lines changed: 108 additions & 176 deletions
@@ -12,127 +12,68 @@
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.backend.generator.inpaint import infill_methods

-from ...backend.generator import Img2Img, Inpaint, InvokeAIGenerator, Txt2Img
+from ...backend.generator import Inpaint, InvokeAIGenerator
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ..util.step_callback import stable_diffusion_step_callback
 from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
 from .image import ImageOutput

+import re
+from ...backend.model_management.lora import ModelPatcher
+from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
+from .model import UNetField, VaeField
+from .compel import ConditioningField
+from contextlib import contextmanager, ExitStack, ContextDecorator
+
 SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
 INFILL_METHODS = Literal[tuple(infill_methods())]
 DEFAULT_INFILL_METHOD = (
     "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
 )


-class SDImageInvocation(BaseModel):
-    """Helper class to provide all Stable Diffusion raster image invocations with additional config"""
-
-    # Schema customisation
-    class Config(InvocationConfig):
-        schema_extra = {
-            "ui": {
-                "tags": ["stable-diffusion", "image"],
-                "type_hints": {
-                    "model": "model",
-                },
-            },
-        }
-
-
-# Text to image
-class TextToImageInvocation(BaseInvocation, SDImageInvocation):
-    """Generates an image using text2img."""
-
-    type: Literal["txt2img"] = "txt2img"
-
-    # Inputs
-    # TODO: consider making prompt optional to enable providing prompt through a link
-    # fmt: off
-    prompt: Optional[str] = Field(description="The prompt to generate an image from")
-    seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
-    steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
-    width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
-    height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
-    cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
-    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
-    model: str = Field(default="", description="The model to use (currently ignored)")
-    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
-    control_model: Optional[str] = Field(default=None, description="The control model to use")
-    control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
-    # fmt: on
+from .latent import get_scheduler

-    # TODO: pass this an emitter method or something? or a session for dispatching?
-    def dispatch_progress(
-        self,
-        context: InvocationContext,
-        source_node_id: str,
-        intermediate_state: PipelineIntermediateState,
-    ) -> None:
-        stable_diffusion_step_callback(
-            context=context,
-            intermediate_state=intermediate_state,
-            node=self.dict(),
-            source_node_id=source_node_id,
-        )
+class OldModelContext(ContextDecorator):
+    model: StableDiffusionGeneratorPipeline

-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        # Handle invalid model parameter
-        model = context.services.model_manager.get_model(self.model,node=self,context=context)
+    def __init__(self, model):
+        self.model = model

-        # loading controlnet image (currently requires pre-processed image)
-        control_image = (
-            None if self.control_image is None
-            else context.services.images.get_pil_image(self.control_image.image_name)
-        )
-        # loading controlnet model
-        if (self.control_model is None or self.control_model==''):
-            control_model = None
-        else:
-            # FIXME: change this to dropdown menu?
-            # FIXME: generalize so don't have to hardcode torch_dtype and device
-            control_model = ControlNetModel.from_pretrained(self.control_model,
-                                                            torch_dtype=torch.float16).to("cuda")
+    def __enter__(self):
+        return self.model

-        # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(
-            context.graph_execution_state_id
-        )
-        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
+    def __exit__(self, *exc):
+        return False

-        txt2img = Txt2Img(model, control_model=control_model)
-        outputs = txt2img.generate(
-            prompt=self.prompt,
-            step_callback=partial(self.dispatch_progress, context, source_node_id),
-            control_image=control_image,
-            **self.dict(
-                exclude={"prompt", "control_image" }
-            ), # Shorthand for passing all of the parameters above manually
-        )
-        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
-        # each time it is called. We only need the first one.
-        generate_output = next(outputs)
+class OldModelInfo:
+    name: str
+    hash: str
+    context: OldModelContext

-        image_dto = context.services.images.create(
-            image=generate_output.image,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            session_id=context.graph_execution_state_id,
-            node_id=self.id,
-            is_intermediate=self.is_intermediate,
+    def __init__(self, name: str, hash: str, model: StableDiffusionGeneratorPipeline):
+        self.name = name
+        self.hash = hash
+        self.context = OldModelContext(
+            model=model,
         )

-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )

+class InpaintInvocation(BaseInvocation):
+    """Generates an image using inpaint."""

-class ImageToImageInvocation(TextToImageInvocation):
-    """Generates an image using img2img."""
+    type: Literal["inpaint"] = "inpaint"

-    type: Literal["img2img"] = "img2img"
+    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
+    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
+    seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
+    steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
+    width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
+    height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
+    cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
+    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
+    unet: UNetField = Field(default=None, description="UNet model")
+    vae: VaeField = Field(default=None, description="Vae model")

     # Inputs
     image: Union[ImageField, None] = Field(description="The input image")
@@ -144,72 +85,6 @@ class ImageToImageInvocation(TextToImageInvocation):
         description="Whether or not the result should be fit to the aspect ratio of the input image",
     )

-    def dispatch_progress(
-        self,
-        context: InvocationContext,
-        source_node_id: str,
-        intermediate_state: PipelineIntermediateState,
-    ) -> None:
-        stable_diffusion_step_callback(
-            context=context,
-            intermediate_state=intermediate_state,
-            node=self.dict(),
-            source_node_id=source_node_id,
-        )
-
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = (
-            None
-            if self.image is None
-            else context.services.images.get_pil_image(self.image.image_name)
-        )
-
-        if self.fit:
-            image = image.resize((self.width, self.height))
-
-        # Handle invalid model parameter
-        model = context.services.model_manager.get_model(self.model,node=self,context=context)
-
-        # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(
-            context.graph_execution_state_id
-        )
-        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
-
-        outputs = Img2Img(model).generate(
-            prompt=self.prompt,
-            init_image=image,
-            step_callback=partial(self.dispatch_progress, context, source_node_id),
-            **self.dict(
-                exclude={"prompt", "image", "mask"}
-            ), # Shorthand for passing all of the parameters above manually
-        )
-
-        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
-        # each time it is called. We only need the first one.
-        generator_output = next(outputs)
-
-        image_dto = context.services.images.create(
-            image=generator_output.image,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            session_id=context.graph_execution_state_id,
-            node_id=self.id,
-            is_intermediate=self.is_intermediate,
-        )
-
-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
-
-
-class InpaintInvocation(ImageToImageInvocation):
-    """Generates an image using inpaint."""
-
-    type: Literal["inpaint"] = "inpaint"
-
     # Inputs
     mask: Union[ImageField, None] = Field(description="The mask")
     seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
@@ -252,6 +127,14 @@ class InpaintInvocation(ImageToImageInvocation):
         description="The amount by which to replace masked areas with latent noise",
     )

+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["stable-diffusion", "image"],
+            },
+        }
+
     def dispatch_progress(
         self,
         context: InvocationContext,
@@ -265,6 +148,49 @@ def dispatch_progress(
             source_node_id=source_node_id,
         )

+    def get_conditioning(self, context):
+        c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
+        uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
+
+        return (uc, c, extra_conditioning_info)
+
+    @contextmanager
+    def load_model_old_way(self, context, scheduler):
+        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
+        vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
+
+        #unet = unet_info.context.model
+        #vae = vae_info.context.model
+
+        with ExitStack() as stack:
+            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
+
+            with vae_info as vae,\
+                 unet_info as unet,\
+                 ModelPatcher.apply_lora_unet(unet, loras):
+
+                device = context.services.model_manager.mgr.cache.execution_device
+                dtype = context.services.model_manager.mgr.cache.precision
+
+                pipeline = StableDiffusionGeneratorPipeline(
+                    vae=vae,
+                    text_encoder=None,
+                    tokenizer=None,
+                    unet=unet,
+                    scheduler=scheduler,
+                    safety_checker=None,
+                    feature_extractor=None,
+                    requires_safety_checker=False,
+                    precision="float16" if dtype == torch.float16 else "float32",
+                    execution_device=device,
+                )
+
+                yield OldModelInfo(
+                    name=self.unet.unet.model_name,
+                    hash="<NO-HASH>",
+                    model=pipeline,
+                )
+
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = (
             None
@@ -277,24 +203,30 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
             else context.services.images.get_pil_image(self.mask.image_name)
         )

-        # Handle invalid model parameter
-        model = context.services.model_manager.get_model(self.model,node=self,context=context)
-
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(
             context.graph_execution_state_id
         )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]

-        outputs = Inpaint(model).generate(
-            prompt=self.prompt,
-            init_image=image,
-            mask_image=mask,
-            step_callback=partial(self.dispatch_progress, context, source_node_id),
-            **self.dict(
-                exclude={"prompt", "image", "mask"}
-            ), # Shorthand for passing all of the parameters above manually
-        )
+        conditioning = self.get_conditioning(context)
+        scheduler = get_scheduler(
+            context=context,
+            scheduler_info=self.unet.scheduler,
+            scheduler_name=self.scheduler,
+        )
+
+        with self.load_model_old_way(context, scheduler) as model:
+            outputs = Inpaint(model).generate(
+                conditioning=conditioning,
+                scheduler=scheduler,
+                init_image=image,
+                mask_image=mask,
+                step_callback=partial(self.dispatch_progress, context, source_node_id),
+                **self.dict(
+                    exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"}
+                ), # Shorthand for passing all of the parameters above manually
+            )

         # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
         # each time it is called. We only need the first one.
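
Two helpers carry the port above: get_conditioning pulls the already-computed conditioning tensors out of context.services.latents, and load_model_old_way adapts the new manager to the legacy Inpaint generator by entering every LoRA record on an ExitStack, patching the UNet via ModelPatcher.apply_lora_unet, and yielding a temporary StableDiffusionGeneratorPipeline wrapped in OldModelInfo. A minimal runnable sketch of that ExitStack idiom, with model_context as a hypothetical stand-in for the manager's model records:

    from contextlib import ExitStack, contextmanager


    @contextmanager
    def model_context(name):
        # Stand-in (hypothetical): "loads" on enter, "releases" on exit,
        # like the records returned by the model manager.
        print(f"load {name}")
        try:
            yield name
        finally:
            print(f"release {name}")


    lora_fields = [("lora_a", 0.75), ("lora_b", 1.0)]

    with ExitStack() as stack:
        # Enter every LoRA context up front; ExitStack keeps them open and
        # releases them all (in reverse order) when the block exits, which
        # is why the pipeline must be assembled and used inside this block.
        loras = [(stack.enter_context(model_context(name)), weight)
                 for name, weight in lora_fields]
        print("patch UNet with", loras, "and run the pipeline")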

invokeai/backend/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -5,7 +5,6 @@
     InvokeAIGeneratorBasicParams,
     InvokeAIGenerator,
     InvokeAIGeneratorOutput,
-    Txt2Img,
     Img2Img,
     Inpaint
 )

invokeai/backend/generator/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -5,7 +5,6 @@
     InvokeAIGenerator,
     InvokeAIGeneratorBasicParams,
     InvokeAIGeneratorOutput,
-    Txt2Img,
     Img2Img,
     Inpaint,
     Generator,
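
Net effect of the two __init__.py hunks: Txt2Img is no longer re-exported from the backend packages, so any downstream import has to drop it. For example (assuming a checkout at this commit):

    # Still valid after this commit:
    from invokeai.backend.generator import Img2Img, Inpaint, InvokeAIGenerator

    # No longer available -- Txt2Img was removed from the re-exports:
    # from invokeai.backend.generator import Txt2Img  # would raise ImportError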
