Skip to content

Commit 58fc824

Browse files
authored
add: better warning messages when handling multiple conditionings. (#2804)
* add: better warning messages when handling multiple conditioning. * fix: handling of controlnet_conditioning_scale
1 parent fab4f3d commit 58fc824

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -537,15 +537,27 @@ def check_inputs(
537537
f" {negative_prompt_embeds.shape}."
538538
)
539539

540-
# Check `image`
540+
# `prompt` needs more sophisticated handling when there are multiple
541+
# conditionings.
542+
if isinstance(self.controlnet, MultiControlNetModel):
543+
if isinstance(prompt, list):
544+
logger.warning(
545+
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
546+
" prompts. The conditionings will be fixed across the prompts."
547+
)
541548

549+
# Check `image`
542550
if isinstance(self.controlnet, ControlNetModel):
543551
self.check_image(image, prompt, prompt_embeds)
544552
elif isinstance(self.controlnet, MultiControlNetModel):
545553
if not isinstance(image, list):
546554
raise TypeError("For multiple controlnets: `image` must be type `list`")
547555

548-
if len(image) != len(self.controlnet.nets):
556+
# When `image` is a nested list:
557+
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
558+
elif any(isinstance(i, list) for i in image):
559+
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
560+
elif len(image) != len(self.controlnet.nets):
549561
raise ValueError(
550562
"For multiple controlnets: `image` must have the same length as the number of controlnets."
551563
)
@@ -556,12 +568,14 @@ def check_inputs(
556568
assert False
557569

558570
# Check `controlnet_conditioning_scale`
559-
560571
if isinstance(self.controlnet, ControlNetModel):
561572
if not isinstance(controlnet_conditioning_scale, float):
562573
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
563574
elif isinstance(self.controlnet, MultiControlNetModel):
564-
if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
575+
if isinstance(controlnet_conditioning_scale, list):
576+
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
577+
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
578+
elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
565579
self.controlnet.nets
566580
):
567581
raise ValueError(

0 commit comments

Comments
 (0)