
Commit 201b843

Feat/controlnet extras (#3596)
Trying to get a few ControlNet extras in before the 3.0 release:

- SegmentAnything ControlNet preprocessor node
- LeResDepth ControlNet preprocessor node (commented out until controlnet_aux v0.0.6 is released and required by InvokeAI)
- TileResampler ControlNet preprocessor node (should be equivalent to the Mikubill/sd-webui-controlnet extension's tile_resampler)
- Fix for the Midas ControlNet preprocessor error with images that have an alpha channel

Example usage of the SegmentAnything preprocessor node:

![Screenshot from 2023-06-26 16-53-44](https://github.com/invoke-ai/InvokeAI/assets/303100/c6278f9a-5f6b-44bd-98b1-fcaf77251a76)
2 parents 00c78b1 + 32883ad commit 201b843
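For readers without the node editor handy, here is a minimal standalone sketch of what the SegmentAnything preprocessor does, run directly against controlnet_aux rather than through InvokeAI's invocation graph. It assumes `controlnet-aux>=0.0.6` is installed and that the `ybelkada/segment-anything` checkpoint referenced in the diff below is reachable; the input and output paths are hypothetical, and it uses the plain `SamDetector` (without the reproducible-color override added further down in the diff).

```python
# Hedged sketch only -- bypasses the InvokeAI node, but mirrors what
# SegmentAnythingProcessorInvocation.run_processor() does in the diff below.
import numpy as np
from PIL import Image
from controlnet_aux import SamDetector

sam = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")

image = Image.open("input.png").convert("RGB")        # hypothetical input image
np_img = np.array(image, dtype=np.uint8)              # the node feeds the detector a uint8 array
control_image = sam(np_img)                           # segmentation map (PIL image by controlnet_aux convention)
control_image.save("segment_anything_control.png")    # hypothetical output path
```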

File tree: 2 files changed (+113, −5 lines)


invokeai/app/invocations/controlnet_image_processors.py

Lines changed: 112 additions & 4 deletions
```diff
@@ -1,10 +1,11 @@
-# InvokeAI nodes for ControlNet image preprocessors
+# Invocations for ControlNet image preprocessors
 # initial implementation by Gregg Helt, 2023
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
 from builtins import float, bool

+import cv2
 import numpy as np
-from typing import Literal, Optional, Union, List
+from typing import Literal, Optional, Union, List, Dict
 from PIL import Image, ImageFilter, ImageOps
 from pydantic import BaseModel, Field, validator

@@ -29,8 +30,13 @@
     ContentShuffleDetector,
     ZoeDetector,
     MediapipeFaceDetector,
+    SamDetector,
+    LeresDetector,
 )

+from controlnet_aux.util import HWC3, ade_palette
+
+
 from .image import ImageOutput, PILInvocationConfig

 CONTROLNET_DEFAULT_MODELS = [
@@ -95,6 +101,9 @@

 CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
 CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
+# crop and fill options not ready yet
+# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
+

 class ControlField(BaseModel):
     image: ImageField = Field(default=None, description="The control image")
@@ -105,7 +114,8 @@ class ControlField(BaseModel):
                                       description="When the ControlNet is first applied (% of total steps)")
     end_step_percent: float = Field(default=1, ge=0, le=1,
                                     description="When the ControlNet is last applied (% of total steps)")
-    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The contorl mode to use")
+    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
+    # resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")

     @validator("control_weight")
     def abs_le_one(cls, v):
@@ -180,7 +190,7 @@ def invoke(self, context: InvocationContext) -> ControlOutput:
             ),
         )

-# TODO: move image processors to separate file (image_analysis.py
+
 class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
     """Base class for invocations that preprocess images for ControlNet"""

@@ -452,6 +462,104 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
     # fmt: on

     def run_processor(self, image):
+        # MediaPipeFaceDetector throws an error if image has alpha channel
+        # so convert to RGB if needed
+        if image.mode == 'RGBA':
+            image = image.convert('RGB')
         mediapipe_face_processor = MediapipeFaceDetector()
         processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
         return processed_image
+
+class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies leres processing to image"""
+    # fmt: off
+    type: Literal["leres_image_processor"] = "leres_image_processor"
+    # Inputs
+    thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
+    thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
+    boost: bool = Field(default=False, description="Whether to use boost mode")
+    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
+    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
+    # fmt: on
+
+    def run_processor(self, image):
+        leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
+        processed_image = leres_processor(image,
+                                          thr_a=self.thr_a,
+                                          thr_b=self.thr_b,
+                                          boost=self.boost,
+                                          detect_resolution=self.detect_resolution,
+                                          image_resolution=self.image_resolution)
+        return processed_image
+
+
+class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+
+    # fmt: off
+    type: Literal["tile_image_processor"] = "tile_image_processor"
+    # Inputs
+    #res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
+    down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
+    # fmt: on
+
+    # tile_resample copied from sd-webui-controlnet/scripts/processor.py
+    def tile_resample(self,
+                      np_img: np.ndarray,
+                      res=512,  # never used?
+                      down_sampling_rate=1.0,
+                      ):
+        np_img = HWC3(np_img)
+        if down_sampling_rate < 1.1:
+            return np_img
+        H, W, C = np_img.shape
+        H = int(float(H) / float(down_sampling_rate))
+        W = int(float(W) / float(down_sampling_rate))
+        np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
+        return np_img
+
+    def run_processor(self, img):
+        np_img = np.array(img, dtype=np.uint8)
+        processed_np_image = self.tile_resample(np_img,
+                                                #res=self.tile_size,
+                                                down_sampling_rate=self.down_sampling_rate
+                                                )
+        processed_image = Image.fromarray(processed_np_image)
+        return processed_image
+
+
+
+
+class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies segment anything processing to image"""
+    # fmt: off
+    type: Literal["segment_anything_processor"] = "segment_anything_processor"
+    # fmt: on
+
+    def run_processor(self, image):
+        # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
+        segment_anything_processor = SamDetectorReproducibleColors.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
+        np_img = np.array(image, dtype=np.uint8)
+        processed_image = segment_anything_processor(np_img)
+        return processed_image
+
+class SamDetectorReproducibleColors(SamDetector):
+
+    # overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
+    # base class show_anns() method randomizes colors,
+    # which seems to also lead to non-reproducible image generation
+    # so using ADE20k color palette instead
+    def show_anns(self, anns: List[Dict]):
+        if len(anns) == 0:
+            return
+        sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
+        h, w = anns[0]['segmentation'].shape
+        final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
+        palette = ade_palette()
+        for i, ann in enumerate(sorted_anns):
+            m = ann['segmentation']
+            img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
+            # doing modulo just in case number of annotated regions exceeds number of colors in palette
+            ann_color = palette[i % len(palette)]
+            img[:, :] = ann_color
+            final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
+        return np.array(final_img, dtype=np.uint8)
```
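The `tile_resample()` helper above is small enough to exercise on its own. Below is a rough standalone sketch of the same down-sampling arithmetic; it needs only numpy and OpenCV, and skips the `HWC3()` normalization from `controlnet_aux.util` by assuming an HxWx3 uint8 array is passed in.

```python
import cv2
import numpy as np

def tile_resample(np_img: np.ndarray, down_sampling_rate: float = 1.0) -> np.ndarray:
    # Same arithmetic as TileResamplerProcessorInvocation.tile_resample() above,
    # assuming the input is already an HxWx3 uint8 RGB array (so HWC3() is omitted).
    if down_sampling_rate < 1.1:   # rates this close to 1 are treated as a no-op
        return np_img
    H, W, _ = np_img.shape
    H = int(float(H) / float(down_sampling_rate))
    W = int(float(W) / float(down_sampling_rate))
    return cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)

img = np.zeros((1024, 768, 3), dtype=np.uint8)
print(tile_resample(img, 1.0).shape)   # (1024, 768, 3) -- unchanged below the 1.1 threshold
print(tile_resample(img, 2.0).shape)   # (512, 384, 3)  -- each dimension halved
```

Note that the node's default `down_sampling_rate=1.0` falls under the 1.1 threshold, so the default configuration is a pure pass-through.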

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -39,7 +39,7 @@ dependencies = [
     "click",
     "clip_anytorch",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
     "compel>=1.2.1",
-    "controlnet-aux>=0.0.4",
+    "controlnet-aux>=0.0.6",
     "timm==0.6.13",  # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
     "datasets",
     "diffusers[torch]~=0.17.1",
```

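A quick, hedged way to confirm an environment actually picked up the bumped dependency: the imports below are exactly the ones the Python diff above relies on, and `LeresDetector` in particular is the reason the pin moved to 0.0.6.

```python
# Sanity-check sketch: these are the imports controlnet_image_processors.py now needs;
# on controlnet-aux releases older than 0.0.6 the LeresDetector import fails.
import controlnet_aux
from controlnet_aux import SamDetector, LeresDetector
from controlnet_aux.util import HWC3, ade_palette

print("controlnet_aux version:", getattr(controlnet_aux, "__version__", "unknown"))
```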