Skip to content

Commit 07a2da4

Browse files
authored
Rewrite controlnet to new model manager (#3665)
2 parents 5d5a497 + f7230d0 commit 07a2da4

File tree

58 files changed

+1091
-621
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+1091
-621
lines changed

invokeai/app/invocations/controlnet_image_processors.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from PIL import Image
1010
from pydantic import BaseModel, Field, validator
1111

12+
from ...backend.model_management import BaseModelType, ModelType
1213
from ..models.image import ImageField, ImageCategory, ResourceOrigin
1314
from .baseinvocation import (
1415
BaseInvocation,
@@ -105,9 +106,15 @@
105106
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
106107

107108

109+
class ControlNetModelField(BaseModel):
110+
"""ControlNet model field"""
111+
112+
model_name: str = Field(description="Name of the ControlNet model")
113+
base_model: BaseModelType = Field(description="Base model")
114+
108115
class ControlField(BaseModel):
109116
image: ImageField = Field(default=None, description="The control image")
110-
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
117+
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
111118
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
112119
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
113120
begin_step_percent: float = Field(default=0, ge=0, le=1,
@@ -118,22 +125,23 @@ class ControlField(BaseModel):
118125
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
119126

120127
@validator("control_weight")
121-
def abs_le_one(cls, v):
122-
"""validate that all abs(values) are <=1"""
128+
def validate_control_weight(cls, v):
129+
"""Validate that all control weights in the valid range"""
123130
if isinstance(v, list):
124131
for i in v:
125-
if abs(i) > 1:
126-
raise ValueError('all abs(control_weight) must be <= 1')
132+
if i < -1 or i > 2:
133+
raise ValueError('Control weights must be within -1 to 2 range')
127134
else:
128-
if abs(v) > 1:
129-
raise ValueError('abs(control_weight) must be <= 1')
135+
if v < -1 or v > 2:
136+
raise ValueError('Control weights must be within -1 to 2 range')
130137
return v
131138
class Config:
132139
schema_extra = {
133140
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
134141
"ui": {
135142
"type_hints": {
136143
"control_weight": "float",
144+
"control_model": "controlnet_model",
137145
# "control_weight": "number",
138146
}
139147
}
@@ -154,10 +162,10 @@ class ControlNetInvocation(BaseInvocation):
154162
type: Literal["controlnet"] = "controlnet"
155163
# Inputs
156164
image: ImageField = Field(default=None, description="The control image")
157-
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
165+
control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
158166
description="control model used")
159167
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
160-
begin_step_percent: float = Field(default=0, ge=0, le=1,
168+
begin_step_percent: float = Field(default=0, ge=-1, le=2,
161169
description="When the ControlNet is first applied (% of total steps)")
162170
end_step_percent: float = Field(default=1, ge=0, le=1,
163171
description="When the ControlNet is last applied (% of total steps)")

invokeai/app/invocations/latent.py

Lines changed: 65 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
22

3+
from contextlib import ExitStack
34
from typing import List, Literal, Optional, Union
45

56
import einops
@@ -11,6 +12,7 @@
1112

1213
from invokeai.app.invocations.metadata import CoreMetadata
1314
from invokeai.app.util.step_callback import stable_diffusion_step_callback
15+
from invokeai.backend.model_management.models.base import ModelType
1416

1517
from ...backend.model_management.lora import ModelPatcher
1618
from ...backend.stable_diffusion import PipelineIntermediateState
@@ -71,16 +73,21 @@ def get_scheduler(
7173
scheduler_name: str,
7274
) -> Scheduler:
7375
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
74-
scheduler_name, SCHEDULER_MAP['ddim'])
76+
scheduler_name, SCHEDULER_MAP['ddim']
77+
)
7578
orig_scheduler_info = context.services.model_manager.get_model(
76-
**scheduler_info.dict())
79+
**scheduler_info.dict()
80+
)
7781
with orig_scheduler_info as orig_scheduler:
7882
scheduler_config = orig_scheduler.config
7983

8084
if "_backup" in scheduler_config:
8185
scheduler_config = scheduler_config["_backup"]
82-
scheduler_config = {**scheduler_config, **
83-
scheduler_extra_config, "_backup": scheduler_config}
86+
scheduler_config = {
87+
**scheduler_config,
88+
**scheduler_extra_config,
89+
"_backup": scheduler_config,
90+
}
8491
scheduler = scheduler_class.from_config(scheduler_config)
8592

8693
# hack copied over from generate.py
@@ -137,8 +144,11 @@ class Config(InvocationConfig):
137144

138145
# TODO: pass this an emitter method or something? or a session for dispatching?
139146
def dispatch_progress(
140-
self, context: InvocationContext, source_node_id: str,
141-
intermediate_state: PipelineIntermediateState) -> None:
147+
self,
148+
context: InvocationContext,
149+
source_node_id: str,
150+
intermediate_state: PipelineIntermediateState,
151+
) -> None:
142152
stable_diffusion_step_callback(
143153
context=context,
144154
intermediate_state=intermediate_state,
@@ -147,11 +157,16 @@ def dispatch_progress(
147157
)
148158

149159
def get_conditioning_data(
150-
self, context: InvocationContext, scheduler) -> ConditioningData:
160+
self,
161+
context: InvocationContext,
162+
scheduler,
163+
) -> ConditioningData:
151164
c, extra_conditioning_info = context.services.latents.get(
152-
self.positive_conditioning.conditioning_name)
165+
self.positive_conditioning.conditioning_name
166+
)
153167
uc, _ = context.services.latents.get(
154-
self.negative_conditioning.conditioning_name)
168+
self.negative_conditioning.conditioning_name
169+
)
155170

156171
conditioning_data = ConditioningData(
157172
unconditioned_embeddings=uc,
@@ -178,7 +193,10 @@ def get_conditioning_data(
178193
return conditioning_data
179194

180195
def create_pipeline(
181-
self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
196+
self,
197+
unet,
198+
scheduler,
199+
) -> StableDiffusionGeneratorPipeline:
182200
# TODO:
183201
# configure_model_padding(
184202
# unet,
@@ -213,6 +231,7 @@ def prep_control_data(
213231
model: StableDiffusionGeneratorPipeline,
214232
control_input: List[ControlField],
215233
latents_shape: List[int],
234+
exit_stack: ExitStack,
216235
do_classifier_free_guidance: bool = True,
217236
) -> List[ControlNetData]:
218237

@@ -238,25 +257,19 @@ def prep_control_data(
238257
control_data = []
239258
control_models = []
240259
for control_info in control_list:
241-
# handle control models
242-
if ("," in control_info.control_model):
243-
control_model_split = control_info.control_model.split(",")
244-
control_name = control_model_split[0]
245-
control_subfolder = control_model_split[1]
246-
print("Using HF model subfolders")
247-
print(" control_name: ", control_name)
248-
print(" control_subfolder: ", control_subfolder)
249-
control_model = ControlNetModel.from_pretrained(
250-
control_name, subfolder=control_subfolder,
251-
torch_dtype=model.unet.dtype).to(
252-
model.device)
253-
else:
254-
control_model = ControlNetModel.from_pretrained(
255-
control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
260+
control_model = exit_stack.enter_context(
261+
context.services.model_manager.get_model(
262+
model_name=control_info.control_model.model_name,
263+
model_type=ModelType.ControlNet,
264+
base_model=control_info.control_model.base_model,
265+
)
266+
)
267+
256268
control_models.append(control_model)
257269
control_image_field = control_info.image
258270
input_image = context.services.images.get_pil_image(
259-
control_image_field.image_name)
271+
control_image_field.image_name
272+
)
260273
# self.image.image_type, self.image.image_name
261274
# FIXME: still need to test with different widths, heights, devices, dtypes
262275
# and add in batch_size, num_images_per_prompt?
@@ -278,7 +291,8 @@ def prep_control_data(
278291
weight=control_info.control_weight,
279292
begin_step_percent=control_info.begin_step_percent,
280293
end_step_percent=control_info.end_step_percent,
281-
control_mode=control_info.control_mode,)
294+
control_mode=control_info.control_mode,
295+
)
282296
control_data.append(control_item)
283297
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
284298
return control_data
@@ -289,7 +303,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
289303

290304
# Get the source node id (we are invoking the prepared node)
291305
graph_execution_state = context.services.graph_execution_manager.get(
292-
context.graph_execution_state_id)
306+
context.graph_execution_state_id
307+
)
293308
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
294309

295310
def step_callback(state: PipelineIntermediateState):
@@ -298,14 +313,17 @@ def step_callback(state: PipelineIntermediateState):
298313
def _lora_loader():
299314
for lora in self.unet.loras:
300315
lora_info = context.services.model_manager.get_model(
301-
**lora.dict(exclude={"weight"}))
316+
**lora.dict(exclude={"weight"})
317+
)
302318
yield (lora_info.context.model, lora.weight)
303319
del lora_info
304320
return
305321

306322
unet_info = context.services.model_manager.get_model(
307-
**self.unet.unet.dict())
308-
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
323+
**self.unet.unet.dict()
324+
)
325+
with ExitStack() as exit_stack,\
326+
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
309327
unet_info as unet:
310328

311329
scheduler = get_scheduler(
@@ -322,6 +340,7 @@ def _lora_loader():
322340
latents_shape=noise.shape,
323341
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
324342
do_classifier_free_guidance=True,
343+
exit_stack=exit_stack,
325344
)
326345

327346
# TODO: Verify the noise is the right size
@@ -374,7 +393,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
374393

375394
# Get the source node id (we are invoking the prepared node)
376395
graph_execution_state = context.services.graph_execution_manager.get(
377-
context.graph_execution_state_id)
396+
context.graph_execution_state_id
397+
)
378398
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
379399

380400
def step_callback(state: PipelineIntermediateState):
@@ -383,14 +403,17 @@ def step_callback(state: PipelineIntermediateState):
383403
def _lora_loader():
384404
for lora in self.unet.loras:
385405
lora_info = context.services.model_manager.get_model(
386-
**lora.dict(exclude={"weight"}))
406+
**lora.dict(exclude={"weight"})
407+
)
387408
yield (lora_info.context.model, lora.weight)
388409
del lora_info
389410
return
390411

391412
unet_info = context.services.model_manager.get_model(
392-
**self.unet.unet.dict())
393-
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
413+
**self.unet.unet.dict()
414+
)
415+
with ExitStack() as exit_stack,\
416+
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
394417
unet_info as unet:
395418

396419
scheduler = get_scheduler(
@@ -407,11 +430,13 @@ def _lora_loader():
407430
latents_shape=noise.shape,
408431
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
409432
do_classifier_free_guidance=True,
433+
exit_stack=exit_stack,
410434
)
411435

412436
# TODO: Verify the noise is the right size
413437
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
414-
latent, device=unet.device, dtype=latent.dtype)
438+
latent, device=unet.device, dtype=latent.dtype
439+
)
415440

416441
timesteps, _ = pipeline.get_img2img_timesteps(
417442
self.steps,
@@ -535,7 +560,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
535560
resized_latents = torch.nn.functional.interpolate(
536561
latents, size=(self.height // 8, self.width // 8),
537562
mode=self.mode, antialias=self.antialias
538-
if self.mode in ["bilinear", "bicubic"] else False,)
563+
if self.mode in ["bilinear", "bicubic"] else False,
564+
)
539565

540566
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
541567
torch.cuda.empty_cache()
@@ -569,7 +595,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
569595
resized_latents = torch.nn.functional.interpolate(
570596
latents, scale_factor=self.scale_factor, mode=self.mode,
571597
antialias=self.antialias
572-
if self.mode in ["bilinear", "bicubic"] else False,)
598+
if self.mode in ["bilinear", "bicubic"] else False,
599+
)
573600

574601
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
575602
torch.cuda.empty_cache()

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@ import { RootState } from 'app/store/store';
1313

1414
const moduleLog = log.child({ namespace: 'controlNet' });
1515

16-
const predicate: AnyListenerPredicate<RootState> = (action, state) => {
16+
const predicate: AnyListenerPredicate<RootState> = (
17+
action,
18+
state,
19+
prevState
20+
) => {
1721
const isActionMatched =
1822
controlNetProcessorParamsChanged.match(action) ||
1923
controlNetModelChanged.match(action) ||
@@ -25,6 +29,16 @@ const predicate: AnyListenerPredicate<RootState> = (action, state) => {
2529
return false;
2630
}
2731

32+
if (controlNetAutoConfigToggled.match(action)) {
33+
// do not process if the user just disabled auto-config
34+
if (
35+
prevState.controlNet.controlNets[action.payload.controlNetId]
36+
.shouldAutoConfig === true
37+
) {
38+
return false;
39+
}
40+
}
41+
2842
const { controlImage, processorType, shouldAutoConfig } =
2943
state.controlNet.controlNets[action.payload.controlNetId];
3044

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import { zMainModel } from 'features/parameters/types/parameterSchemas';
1010
import { addToast } from 'features/system/store/systemSlice';
1111
import { forEach } from 'lodash-es';
1212
import { startAppListening } from '..';
13+
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
1314

1415
const moduleLog = log.child({ module: 'models' });
1516

@@ -51,7 +52,14 @@ export const addModelSelectedListener = () => {
5152
modelsCleared += 1;
5253
}
5354

54-
// TODO: handle incompatible controlnet; pending model manager support
55+
const { controlNets } = state.controlNet;
56+
forEach(controlNets, (controlNet, controlNetId) => {
57+
if (controlNet.model?.base_model !== base_model) {
58+
dispatch(controlNetRemoved({ controlNetId }));
59+
modelsCleared += 1;
60+
}
61+
});
62+
5563
if (modelsCleared > 0) {
5664
dispatch(
5765
addToast(

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import {
1111
import { forEach, some } from 'lodash-es';
1212
import { modelsApi } from 'services/api/endpoints/models';
1313
import { startAppListening } from '..';
14+
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
1415

1516
const moduleLog = log.child({ module: 'models' });
1617

@@ -127,7 +128,22 @@ export const addModelsLoadedListener = () => {
127128
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
128129
effect: async (action, { getState, dispatch }) => {
129130
// ControlNet models loaded - need to remove missing ControlNets from state
130-
// TODO: pending model manager controlnet support
131+
const controlNets = getState().controlNet.controlNets;
132+
133+
forEach(controlNets, (controlNet, controlNetId) => {
134+
const isControlNetAvailable = some(
135+
action.payload.entities,
136+
(m) =>
137+
m?.model_name === controlNet?.model?.model_name &&
138+
m?.base_model === controlNet?.model?.base_model
139+
);
140+
141+
if (isControlNetAvailable) {
142+
return;
143+
}
144+
145+
dispatch(controlNetRemoved({ controlNetId }));
146+
});
131147
},
132148
});
133149
};

0 commit comments

Comments
 (0)