From 4cb28c982cb61b6abfd5627eb9c8e3ec3ec90257 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 18 Mar 2023 04:42:10 +0900 Subject: [PATCH 1/9] add guess mode (WIP) --- src/diffusers/models/controlnet.py | 25 +++++++++++++++++-- .../pipeline_stable_diffusion_controlnet.py | 4 +++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 0d59605fe046..c804ae7daf99 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -449,6 +449,7 @@ def forward( timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, return_dict: bool = True, ) -> Union[ControlNetOutput, Tuple]: # check channel order @@ -467,6 +468,14 @@ def forward( attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) + if guess_mode: + assert sample.shape[0] == controlnet_cond.shape[0] == encoder_hidden_states.shape[0] + assert sample.shape[0] == 2 # TODO: batch!=2 + # extract cond batch (remove uncond batch) + sample = sample[0, :, :, :].unsqueeze(0) + controlnet_cond = controlnet_cond[0, :, :, :].unsqueeze(0) + encoder_hidden_states = encoder_hidden_states[0, :, :].unsqueeze(0) + # 1. time timesteps = timestep if not torch.is_tensor(timesteps): @@ -549,8 +558,20 @@ def forward( mid_block_res_sample = self.controlnet_mid_block(sample) # 6. 
scaling - down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] - mid_block_res_sample *= conditioning_scale + if guess_mode: + assert len(down_block_res_samples) == 12 + # magic coeff number from: + # https://github.com/lllyasviel/ControlNet/blob/16ea3b5379c1e78a4bc8e3fc9cae8d65c42511b1/gradio_canny2image.py#L52 + scales = [conditioning_scale * (0.825 ** float(12 - i)) for i in range(13)] + down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + mid_block_res_sample *= scales[-1] # last one + + # fill zero to uncond batch + # down_block_res_samples = [torch.cat([d, torch.zeros_like(d)]) for d in down_block_res_samples] + # mid_block_res_sample = torch.cat([mid_block_res_sample, torch.zeros_like(mid_block_res_sample)]) + else: + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample *= conditioning_scale if not return_dict: return (down_block_res_samples, mid_block_res_sample) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index fd82281005ad..7500b60b1a45 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -117,6 +117,7 @@ def forward( timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, return_dict: bool = True, ) -> Union[ControlNetOutput, Tuple]: for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): @@ -130,6 +131,7 @@ def forward( timestep_cond, attention_mask, cross_attention_kwargs, + guess_mode, return_dict, ) @@ -724,6 +726,7 @@ def __call__( callback_steps: int = 1, cross_attention_kwargs: 
Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, ): r""" Function invoked when calling the pipeline for generation. @@ -917,6 +920,7 @@ def __call__( encoder_hidden_states=prompt_embeds, controlnet_cond=image, conditioning_scale=controlnet_conditioning_scale, + guess_mode=guess_mode, return_dict=False, ) From ab298cac85e7b28be629ab941d6c58144bf18599 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sun, 19 Mar 2023 21:56:47 +0900 Subject: [PATCH 2/9] fix uncond/cond order --- src/diffusers/models/controlnet.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index c804ae7daf99..6df5b896fe38 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -472,9 +472,9 @@ def forward( assert sample.shape[0] == controlnet_cond.shape[0] == encoder_hidden_states.shape[0] assert sample.shape[0] == 2 # TODO: batch!=2 # extract cond batch (remove uncond batch) - sample = sample[0, :, :, :].unsqueeze(0) - controlnet_cond = controlnet_cond[0, :, :, :].unsqueeze(0) - encoder_hidden_states = encoder_hidden_states[0, :, :].unsqueeze(0) + sample = sample[1, :, :, :].unsqueeze(0) + controlnet_cond = controlnet_cond[1, :, :, :].unsqueeze(0) + encoder_hidden_states = encoder_hidden_states[1, :, :].unsqueeze(0) # 1. 
time timesteps = timestep @@ -567,8 +567,8 @@ def forward( mid_block_res_sample *= scales[-1] # last one # fill zero to uncond batch - # down_block_res_samples = [torch.cat([d, torch.zeros_like(d)]) for d in down_block_res_samples] - # mid_block_res_sample = torch.cat([mid_block_res_sample, torch.zeros_like(mid_block_res_sample)]) + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) else: down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] mid_block_res_sample *= conditioning_scale From 0b377e711b73e95b5f5814fdd7db808864c7913f Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 8 Apr 2023 01:57:31 +0900 Subject: [PATCH 3/9] support guidance_scale=1.0 and batch != 1 --- src/diffusers/models/controlnet.py | 12 ------- .../pipeline_stable_diffusion_controlnet.py | 32 ++++++++++++++++--- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index dbeeb62c57aa..9336ae39b021 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -475,14 +475,6 @@ def forward( attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) - if guess_mode: - assert sample.shape[0] == controlnet_cond.shape[0] == encoder_hidden_states.shape[0] - assert sample.shape[0] == 2 # TODO: batch!=2 - # extract cond batch (remove uncond batch) - sample = sample[1, :, :, :].unsqueeze(0) - controlnet_cond = controlnet_cond[1, :, :, :].unsqueeze(0) - encoder_hidden_states = encoder_hidden_states[1, :, :].unsqueeze(0) - # 1. 
time timesteps = timestep if not torch.is_tensor(timesteps): @@ -572,10 +564,6 @@ def forward( scales = [conditioning_scale * (0.825 ** float(12 - i)) for i in range(13)] down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] mid_block_res_sample *= scales[-1] # last one - - # fill zero to uncond batch - down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] - mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) else: down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] mid_block_res_sample *= conditioning_scale diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index f3094427f87a..d0d39b5fef06 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -629,7 +629,16 @@ def check_image(self, image, prompt, prompt_embeds): ) def prepare_image( - self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance, + guess_mode, ): if not isinstance(image, torch.Tensor): if isinstance(image, PIL.Image.Image): @@ -666,7 +675,7 @@ def prepare_image( image = image.to(device=device, dtype=dtype) - if do_classifier_free_guidance: + if do_classifier_free_guidance and not guess_mode: image = torch.cat([image] * 2) return image @@ -886,6 +895,7 @@ def __call__( device=device, dtype=self.controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, ) elif isinstance(self.controlnet, MultiControlNetModel): images = [] @@ -900,6 +910,7 @@ def __call__( device=device, 
dtype=self.controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, ) images.append(image_) @@ -937,16 +948,29 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # only use the cond batch for the controlnet + controlnet_latent_model_input = latents + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + controlnet_latent_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + down_block_res_samples, mid_block_res_sample = self.controlnet( - latent_model_input, + controlnet_latent_model_input, t, - encoder_hidden_states=prompt_embeds, + encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=image, conditioning_scale=controlnet_conditioning_scale, guess_mode=guess_mode, return_dict=False, ) + if guess_mode and do_classifier_free_guidance: + # fill zero to uncond batch + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + # predict the noise residual noise_pred = self.unet( latent_model_input, From a2b7e4246bd8bcffff77a4342979b8001f9af07c Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 8 Apr 2023 02:20:03 +0900 Subject: [PATCH 4/9] remove magic coeff --- src/diffusers/models/controlnet.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 9336ae39b021..4f1ffe604578 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -558,10 +558,8 @@ def forward( # 6. 
scaling if guess_mode: - assert len(down_block_res_samples) == 12 - # magic coeff number from: - # https://github.com/lllyasviel/ControlNet/blob/16ea3b5379c1e78a4bc8e3fc9cae8d65c42511b1/gradio_canny2image.py#L52 - scales = [conditioning_scale * (0.825 ** float(12 - i)) for i in range(13)] + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1) # 0.1 to 1.0 + scales *= conditioning_scale down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] mid_block_res_sample *= scales[-1] # last one else: From 335075424af2b7f10de9f160c018520b0c2a8952 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 8 Apr 2023 02:29:20 +0900 Subject: [PATCH 5/9] add docstring --- .../stable_diffusion/pipeline_stable_diffusion_controlnet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index d0d39b5fef06..5709f15e058f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -831,6 +831,10 @@ def __call__( The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original unet. If multiple ControlNets are specified in init, you can set the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try its best to recognize the content of the input image even + if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ Examples: Returns: From 1828cdf356d5286a9789f5b7001cfaa7a66210b9 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 8 Apr 2023 02:49:24 +0900 Subject: [PATCH 6/9] add intergration test --- .../test_stable_diffusion_controlnet.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py index d556e6318f43..5e73692c8d87 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py @@ -553,6 +553,38 @@ def test_sequential_cpu_offloading(self): # make sure that less than 7 GB is allocated assert mem_bytes < 4 * 10**9 + def test_canny_guess_mode(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") + + pipe = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ) + + output = pipe( + prompt, + image, + generator=generator, + output_type="np", + num_inference_steps=3, + guidance_scale=3.0, + guess_mode=True, + ) + + image = output.images[0] + assert image.shape == (768, 512, 3) + + image_slice = image[-3:, -3:, -1] + expected_slice = np.array([0.2724, 0.2846, 0.2724, 0.3843, 0.3682, 0.2736, 0.4675, 0.3862, 0.2887]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + @slow @require_torch_gpu From 1819689b417b66795f920a5286299acba5e3d4c3 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Thu, 13 Apr 2023 00:22:09 +0900 Subject: [PATCH 7/9] add document to controlnet.mdx --- 
.../pipelines/stable_diffusion/controlnet.mdx | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx b/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx index 5a4cfa41ca43..73cfbee61c9e 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx @@ -242,6 +242,41 @@ image.save("./multi_controlnet_output.png") +### Guess Mode + +Guess Mode is [a ControlNet feature that was implemented](https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode) after the publication of [the paper](https://arxiv.org/abs/2302.05543). The description states: + +>In this mode, the ControlNet encoder will try best to recognize the content of the input control map, like depth map, edge map, scribbles, etc, even if you remove all prompts. + +#### The core implementation: + +It adjusts the scale of the output residuals from ControlNet by a fixed ratio depending on the block depth. The shallowest DownBlock corresponds to `0.1`. As the blocks get deeper, the scale increases exponentially, and the scale for the output of the MidBlock becomes `1.0`. + +Since the core implementation is just this, **it does not have any impact on prompt conditioning**. While it is common to use it without specifying any prompts, it is also possible to provide prompts if desired. + +#### Usage: + +Just specify `guess_mode=True` in the pipe() function. A `guidance_scale` between 3.0 and 5.0 is [recommended](https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode). 
+```py +from diffusers import StableDiffusionControlNetPipeline, ControlNetModel +import torch + +controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") +pipe = StableDiffusionControlNetPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", controlnet=controlnet).to( + "cuda" +) +image = pipe("", image=canny_image, guess_mode=True, guidance_scale=3.0).images[0] +image.save("guess_mode_generated.png") +``` + +#### Output image comparison: +Canny Control Example +|no guess_mode with prompt|guess_mode without prompt| +|---|---| +||| + + + ## Available checkpoints ControlNet requires a *control image* in addition to the text-to-image *prompt*. From c70329429c73fe620ce40d8dd2786ee5fabf27c1 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Thu, 13 Apr 2023 00:48:22 +0900 Subject: [PATCH 8/9] made the comments a bit more explanatory --- .../pipeline_stable_diffusion_controlnet.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index 5709f15e058f..c9484afec959 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -953,7 +953,7 @@ def __call__( # controlnet(s) inference if guess_mode and do_classifier_free_guidance: - # only use the cond batch for the controlnet + # Infer ControlNet only for the conditional batch. controlnet_latent_model_input = latents controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] else: @@ -971,7 +971,9 @@ def __call__( ) if guess_mode and do_classifier_free_guidance: - # fill zero to uncond batch + # ControlNet was inferred only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) From 81fe0000c9a0f9e4362129469241b5ceea21da57 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Thu, 13 Apr 2023 01:07:47 +0900 Subject: [PATCH 9/9] fix table --- docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx b/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx index 73cfbee61c9e..af859177c002 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx @@ -271,6 +271,7 @@ image.save("guess_mode_generated.png") #### Output image comparison: Canny Control Example + |no guess_mode with prompt|guess_mode without prompt| |---|---| |||