1- # InvokeAI nodes for ControlNet image preprocessors
1+ # Invocations for ControlNet image preprocessors
22# initial implementation by Gregg Helt, 2023
33# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
44from builtins import float , bool
55
6+ import cv2
67import numpy as np
7- from typing import Literal , Optional , Union , List
8+ from typing import Literal , Optional , Union , List , Dict
89from PIL import Image , ImageFilter , ImageOps
910from pydantic import BaseModel , Field , validator
1011
2930 ContentShuffleDetector ,
3031 ZoeDetector ,
3132 MediapipeFaceDetector ,
33+ SamDetector ,
34+ LeresDetector ,
3235)
3336
37+ from controlnet_aux .util import HWC3 , ade_palette
38+
39+
3440from .image import ImageOutput , PILInvocationConfig
3541
3642CONTROLNET_DEFAULT_MODELS = [
95101
96102CONTROLNET_NAME_VALUES = Literal [tuple (CONTROLNET_DEFAULT_MODELS )]
97103CONTROLNET_MODE_VALUES = Literal [tuple (["balanced" , "more_prompt" , "more_control" , "unbalanced" ])]
104+ # crop and fill options not ready yet
105+ # CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
106+
98107
99108class ControlField (BaseModel ):
100109 image : ImageField = Field (default = None , description = "The control image" )
@@ -105,7 +114,8 @@ class ControlField(BaseModel):
105114 description = "When the ControlNet is first applied (% of total steps)" )
106115 end_step_percent : float = Field (default = 1 , ge = 0 , le = 1 ,
107116 description = "When the ControlNet is last applied (% of total steps)" )
108- control_mode : CONTROLNET_MODE_VALUES = Field (default = "balanced" , description = "The contorl mode to use" )
117+ control_mode : CONTROLNET_MODE_VALUES = Field (default = "balanced" , description = "The control mode to use" )
118+ # resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
109119
110120 @validator ("control_weight" )
111121 def abs_le_one (cls , v ):
@@ -180,7 +190,7 @@ def invoke(self, context: InvocationContext) -> ControlOutput:
180190 ),
181191 )
182192
183- # TODO: move image processors to separate file (image_analysis.py
193+
184194class ImageProcessorInvocation (BaseInvocation , PILInvocationConfig ):
185195 """Base class for invocations that preprocess images for ControlNet"""
186196
@@ -452,6 +462,104 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
452462 # fmt: on
453463
def run_processor(self, image):
    """Run the MediaPipe face detector on ``image`` and return its output.

    ``image`` is expected to be a PIL image (its ``mode`` attribute and
    ``convert`` method are used). The detector is called with this
    invocation's ``max_faces`` and ``min_confidence`` settings.
    """
    # MediapipeFaceDetector throws an error if the image has an alpha channel.
    # Checking only for 'RGBA' misses the other PIL modes that carry alpha or
    # are not plain 3-channel (e.g. 'LA', 'PA', 'P'), so normalize anything
    # that is not already RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")
    mediapipe_face_processor = MediapipeFaceDetector()
    processed_image = mediapipe_face_processor(
        image,
        max_faces=self.max_faces,
        min_confidence=self.min_confidence,
    )
    return processed_image
472+
class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies leres processing to image"""
    # fmt: off
    type: Literal["leres_image_processor"] = "leres_image_processor"
    # Inputs
    thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
    thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
    boost: bool = Field(default=False, description="Whether to use boost mode")
    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    # fmt: on

    def run_processor(self, image):
        """Load the pretrained LeresDetector and apply it to ``image``.

        Every input field of this invocation is forwarded to the detector
        as a keyword argument; the detector's result is returned unchanged.
        """
        detector = LeresDetector.from_pretrained("lllyasviel/Annotators")
        # Gather the detector settings in one place so the call stays short.
        settings = {
            "thr_a": self.thr_a,
            "thr_b": self.thr_b,
            "boost": self.boost,
            "detect_resolution": self.detect_resolution,
            "image_resolution": self.image_resolution,
        }
        return detector(image, **settings)
494+
495+
class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies tile resampling processing to image"""
    # fmt: off
    type: Literal["tile_image_processor"] = "tile_image_processor"
    # Inputs
    #res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
    down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
    # fmt: on

    # tile_resample copied from sd-webui-controlnet/scripts/processor.py
    def tile_resample(self,
                      np_img: np.ndarray,
                      res=512,  # never used; kept so the signature stays compatible
                      down_sampling_rate=1.0,
                      ):
        """Shrink ``np_img`` by ``down_sampling_rate`` using area interpolation.

        The image is first passed through ``HWC3`` (presumably normalizing it
        to an H x W x 3 array -- the original code unpacked a 3-element shape
        from its result). Rates below 1.1 return that array without resizing.

        Args:
            np_img: image as a numpy array.
            res: ignored by this implementation.
            down_sampling_rate: divisor applied to both height and width.

        Returns:
            The (possibly down-sampled) image as a numpy array.
        """
        np_img = HWC3(np_img)
        if down_sampling_rate < 1.1:
            return np_img
        height, width = np_img.shape[:2]
        # int() truncates, so a side shorter than the rate would become 0 and
        # cv2.resize rejects an empty target size; keep at least one pixel.
        new_height = max(1, int(float(height) / float(down_sampling_rate)))
        new_width = max(1, int(float(width) / float(down_sampling_rate)))
        # cv2 takes the target size as (width, height).
        return cv2.resize(np_img, (new_width, new_height), interpolation=cv2.INTER_AREA)

    def run_processor(self, img):
        """Down-sample ``img`` by this invocation's rate and return a PIL image."""
        np_img = np.array(img, dtype=np.uint8)
        # `res` is not forwarded: tile_resample does not use it.
        processed_np_image = self.tile_resample(
            np_img,
            down_sampling_rate=self.down_sampling_rate,
        )
        return Image.fromarray(processed_np_image)
528+
529+
530+
531+
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies segment anything processing to image"""
    # fmt: off
    type: Literal["segment_anything_processor"] = "segment_anything_processor"
    # fmt: on

    def run_processor(self, image):
        """Segment ``image`` and return the detector's output.

        The image is converted to a uint8 numpy array before being handed
        to the detector.
        """
        # SamDetectorReproducibleColors is loaded here in place of the plain
        # SamDetector, using the same pretrained checkpoint location.
        detector = SamDetectorReproducibleColors.from_pretrained(
            "ybelkada/segment-anything",
            subfolder="checkpoints",
        )
        pixels = np.array(image, dtype=np.uint8)
        return detector(pixels)
544+
class SamDetectorReproducibleColors(SamDetector):
    """SamDetector whose segmentation image uses a fixed color palette."""

    # Overrides SamDetector.show_anns(): the base class method randomizes the
    # segment colors, which seems to also lead to non-reproducible image
    # generation, so the ADE20k color palette is used instead.
    def show_anns(self, anns: List[Dict]):
        """Paint every annotation mask onto one RGB image.

        Annotations are drawn from largest to smallest ``area``, so later
        (smaller) masks are pasted over earlier ones. Each annotation takes
        the palette entry matching its position in that ordering.

        Returns:
            A uint8 numpy array of the painted image, or ``None`` when
            ``anns`` is empty.
        """
        if not anns:
            return
        largest_first = sorted(anns, key=lambda entry: entry['area'], reverse=True)
        height, width = anns[0]['segmentation'].shape
        canvas = Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8), mode="RGB")
        colors = ade_palette()
        for index, ann in enumerate(largest_first):
            mask = ann['segmentation']
            # Modulo wraps around in case there are more regions than palette colors.
            color = colors[index % len(colors)]
            layer = np.full((mask.shape[0], mask.shape[1], 3), color, dtype=np.uint8)
            # The mask scaled to 0/255 acts as the paste stencil for this color layer.
            stencil = Image.fromarray(np.uint8(mask * 255))
            canvas.paste(Image.fromarray(layer, mode="RGB"), (0, 0), stencil)
        return np.array(canvas, dtype=np.uint8)
0 commit comments