 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.backend.generator.inpaint import infill_methods
 
-from ...backend.generator import Img2Img, Inpaint, InvokeAIGenerator, Txt2Img
+from ...backend.generator import Inpaint, InvokeAIGenerator
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ..util.step_callback import stable_diffusion_step_callback
 from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
 from .image import ImageOutput
 
+import re
+from ...backend.model_management.lora import ModelPatcher
+from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
+from .model import UNetField, VaeField
+from .compel import ConditioningField
+from contextlib import contextmanager, ExitStack, ContextDecorator
+
 SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
 INFILL_METHODS = Literal[tuple(infill_methods())]
 DEFAULT_INFILL_METHOD = (
     "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
 )
 
 
-class SDImageInvocation(BaseModel):
-    """Helper class to provide all Stable Diffusion raster image invocations with additional config"""
-
-    # Schema customisation
-    class Config(InvocationConfig):
-        schema_extra = {
-            "ui": {
-                "tags": ["stable-diffusion", "image"],
-                "type_hints": {
-                    "model": "model",
-                },
-            },
-        }
-
-
-# Text to image
-class TextToImageInvocation(BaseInvocation, SDImageInvocation):
-    """Generates an image using text2img."""
-
-    type: Literal["txt2img"] = "txt2img"
-
-    # Inputs
-    # TODO: consider making prompt optional to enable providing prompt through a link
-    # fmt: off
-    prompt: Optional[str] = Field(description="The prompt to generate an image from")
-    seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
-    steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
-    width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image",)
-    height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image",)
-    cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt",)
-    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use")
-    model: str = Field(default="", description="The model to use (currently ignored)")
-    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation",)
-    control_model: Optional[str] = Field(default=None, description="The control model to use")
-    control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
-    # fmt: on
+from .latent import get_scheduler
 
-    # TODO: pass this an emitter method or something? or a session for dispatching?
-    def dispatch_progress(
-        self,
-        context: InvocationContext,
-        source_node_id: str,
-        intermediate_state: PipelineIntermediateState,
-    ) -> None:
-        stable_diffusion_step_callback(
-            context=context,
-            intermediate_state=intermediate_state,
-            node=self.dict(),
-            source_node_id=source_node_id,
-        )
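+# Compatibility shim: wraps an already-loaded pipeline in a no-op context manager so
+# the legacy Inpaint generator can keep using the "with model_info.context as model:" pattern.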
+class OldModelContext(ContextDecorator):
+    model: StableDiffusionGeneratorPipeline
 
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        # Handle invalid model parameter
-        model = context.services.model_manager.get_model(self.model, node=self, context=context)
+    def __init__(self, model):
+        self.model = model
 
-        # loading controlnet image (currently requires pre-processed image)
-        control_image = (
-            None if self.control_image is None
-            else context.services.images.get_pil_image(self.control_image.image_name)
-        )
-        # loading controlnet model
-        if (self.control_model is None or self.control_model == ''):
-            control_model = None
-        else:
-            # FIXME: change this to dropdown menu?
-            # FIXME: generalize so don't have to hardcode torch_dtype and device
-            control_model = ControlNetModel.from_pretrained(self.control_model,
-                                                            torch_dtype=torch.float16).to("cuda")
+    def __enter__(self):
+        return self.model
 
-        # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(
-            context.graph_execution_state_id
-        )
-        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
+    def __exit__(self, *exc):
+        return False
 
-        txt2img = Txt2Img(model, control_model=control_model)
-        outputs = txt2img.generate(
-            prompt=self.prompt,
-            step_callback=partial(self.dispatch_progress, context, source_node_id),
-            control_image=control_image,
-            **self.dict(
-                exclude={"prompt", "control_image"}
-            ),  # Shorthand for passing all of the parameters above manually
-        )
-        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
-        # each time it is called. We only need the first one.
-        generate_output = next(outputs)
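+# Minimal stand-in for the model manager's model-info record: it carries a name,
+# a hash (not computed on this path, so a placeholder is used), and a context.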
+class OldModelInfo:
+    name: str
+    hash: str
+    context: OldModelContext
 
-        image_dto = context.services.images.create(
-            image=generate_output.image,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            session_id=context.graph_execution_state_id,
-            node_id=self.id,
-            is_intermediate=self.is_intermediate,
+    def __init__(self, name: str, hash: str, model: StableDiffusionGeneratorPipeline):
+        self.name = name
+        self.hash = hash
+        self.context = OldModelContext(
+            model=model,
         )
 
-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
 
+class InpaintInvocation(BaseInvocation):
+    """Generates an image using inpaint."""
 
-class ImageToImageInvocation(TextToImageInvocation):
-    """Generates an image using img2img."""
+    type: Literal["inpaint"] = "inpaint"
 
-    type: Literal["img2img"] = "img2img"
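+    # Model inputs now arrive as structured fields wired from upstream loader and
+    # conditioning nodes rather than as a single model-name string.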
+    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
+    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
+    seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
+    steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
+    width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image",)
+    height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image",)
+    cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt",)
+    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use")
+    unet: UNetField = Field(default=None, description="UNet model")
+    vae: VaeField = Field(default=None, description="Vae model")
 
     # Inputs
     image: Union[ImageField, None] = Field(description="The input image")
@@ -144,72 +85,6 @@ class ImageToImageInvocation(TextToImageInvocation):
         description="Whether or not the result should be fit to the aspect ratio of the input image",
     )
 
-    def dispatch_progress(
-        self,
-        context: InvocationContext,
-        source_node_id: str,
-        intermediate_state: PipelineIntermediateState,
-    ) -> None:
-        stable_diffusion_step_callback(
-            context=context,
-            intermediate_state=intermediate_state,
-            node=self.dict(),
-            source_node_id=source_node_id,
-        )
-
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = (
-            None
-            if self.image is None
-            else context.services.images.get_pil_image(self.image.image_name)
-        )
-
-        if self.fit:
-            image = image.resize((self.width, self.height))
-
-        # Handle invalid model parameter
-        model = context.services.model_manager.get_model(self.model, node=self, context=context)
-
-        # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(
-            context.graph_execution_state_id
-        )
-        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
-
-        outputs = Img2Img(model).generate(
-            prompt=self.prompt,
-            init_image=image,
-            step_callback=partial(self.dispatch_progress, context, source_node_id),
-            **self.dict(
-                exclude={"prompt", "image", "mask"}
-            ),  # Shorthand for passing all of the parameters above manually
-        )
-
-        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
-        # each time it is called. We only need the first one.
-        generator_output = next(outputs)
-
-        image_dto = context.services.images.create(
-            image=generator_output.image,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            session_id=context.graph_execution_state_id,
-            node_id=self.id,
-            is_intermediate=self.is_intermediate,
-        )
-
-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
-
-
-class InpaintInvocation(ImageToImageInvocation):
-    """Generates an image using inpaint."""
-
-    type: Literal["inpaint"] = "inpaint"
-
     # Inputs
     mask: Union[ImageField, None] = Field(description="The mask")
     seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
@@ -252,6 +127,14 @@ class InpaintInvocation(ImageToImageInvocation):
         description="The amount by which to replace masked areas with latent noise",
     )
 
+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["stable-diffusion", "image"],
+            },
+        }
+
     def dispatch_progress(
         self,
         context: InvocationContext,
@@ -265,6 +148,49 @@ def dispatch_progress(
             source_node_id=source_node_id,
         )
 
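+    # Fetch precomputed conditioning tensors from the latents service; the legacy
+    # generator expects the tuple ordered (unconditioned, conditioned, extra_info).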
+    def get_conditioning(self, context):
+        c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
+        uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
+
+        return (uc, c, extra_conditioning_info)
+
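+    # Hand-assemble a StableDiffusionGeneratorPipeline from the separately loaded
+    # UNet and VAE submodels so the legacy Inpaint generator can run against the
+    # new model manager. Text encoder and tokenizer are omitted because the
+    # conditioning is already precomputed upstream.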
+    @contextmanager
+    def load_model_old_way(self, context, scheduler):
+        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
+        vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
+
+        #unet = unet_info.context.model
+        #vae = vae_info.context.model
+
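+        # ExitStack keeps each LoRA's load context open until the outer "with" exits;
+        # only UNet LoRAs are applied, since there is no text encoder on this path.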
+        with ExitStack() as stack:
+            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
+
+            with vae_info as vae,\
+                 unet_info as unet,\
+                 ModelPatcher.apply_lora_unet(unet, loras):
+
+                device = context.services.model_manager.mgr.cache.execution_device
+                dtype = context.services.model_manager.mgr.cache.precision
+
+                pipeline = StableDiffusionGeneratorPipeline(
+                    vae=vae,
+                    text_encoder=None,
+                    tokenizer=None,
+                    unet=unet,
+                    scheduler=scheduler,
+                    safety_checker=None,
+                    feature_extractor=None,
+                    requires_safety_checker=False,
+                    precision="float16" if dtype == torch.float16 else "float32",
+                    execution_device=device,
+                )
+
+                yield OldModelInfo(
+                    name=self.unet.unet.model_name,
+                    hash="<NO-HASH>",
+                    model=pipeline,
+                )
+
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = (
             None
@@ -277,24 +203,30 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
             else context.services.images.get_pil_image(self.mask.image_name)
         )
 
-        # Handle invalid model parameter
-        model = context.services.model_manager.get_model(self.model, node=self, context=context)
-
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(
             context.graph_execution_state_id
         )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
-        outputs = Inpaint(model).generate(
-            prompt=self.prompt,
-            init_image=image,
-            mask_image=mask,
-            step_callback=partial(self.dispatch_progress, context, source_node_id),
-            **self.dict(
-                exclude={"prompt", "image", "mask"}
-            ),  # Shorthand for passing all of the parameters above manually
-        )
+        conditioning = self.get_conditioning(context)
+        scheduler = get_scheduler(
+            context=context,
+            scheduler_info=self.unet.scheduler,
+            scheduler_name=self.scheduler,
+        )
+
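+        # Fields already consumed above (conditioning, scheduler, image, mask) are
+        # excluded from the **self.dict() pass-through to avoid duplicate arguments.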
+        with self.load_model_old_way(context, scheduler) as model:
+            outputs = Inpaint(model).generate(
+                conditioning=conditioning,
+                scheduler=scheduler,
+                init_image=image,
+                mask_image=mask,
+                step_callback=partial(self.dispatch_progress, context, source_node_id),
+                **self.dict(
+                    exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"}
+                ),  # Shorthand for passing all of the parameters above manually
+            )
 
         # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
         # each time it is called. We only need the first one.