
from_pipe method is not working with MultiControlNetModel pipeline component #7982

@detkov


Describe the bug

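I tried to initialize a StableDiffusionXLControlNetInpaintPipeline with the from_pipe method, starting from a StableDiffusionXLControlNetPipeline that uses multiple ControlNets (a MultiControlNetModel). This fails. A warning says the controlnet component is not carried over because the new pipeline expects only the ControlNetModel class. A ValueError follows because the controlnet component is then missing.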

Passing the component explicitly, as in from_pipe(pipeline, controlnet=pipeline.controlnet), avoids the error. IMHO this should not be necessary, because the MultiControlNetModel is already a component of the source pipeline.

Reproduction

from diffusers import StableDiffusionXLControlNetPipeline, StableDiffusionXLControlNetInpaintPipeline, ControlNetModel, AutoencoderKL
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
import torch


controlnet_1 = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
controlnet_2 = ControlNetModel.from_pretrained(
    "diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16
)
controlnets = MultiControlNetModel([controlnet_1, controlnet_2])

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", vae=vae, controlnet=controlnets, variant='fp16', torch_dtype=torch.float16, add_watermarker=False,
).to('cuda')
pipeline.enable_vae_slicing()

pipeline_inpaint = StableDiffusionXLControlNetInpaintPipeline.from_pipe(pipeline)

Logs

component controlnet is not switched over to new pipeline because type does not match the expected. controlnet is <class 'diffusers.pipelines.controlnet.multicontrolnet.MultiControlNetModel'> while the new pipeline expect (<class 'diffusers.models.controlnet.ControlNetModel'>,). please pass the component of the correct type to the new pipeline. `from_pipe(..., controlnet=controlnet)`
Traceback (most recent call last):
  File "/root/text2image-core/temp_bug.py", line 20, in <module>
    pipeline_inpaint = StableDiffusionXLControlNetInpaintPipeline.from_pipe(pipeline)
  File "/opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/diffusers/pipelines/pipeline_utils.py", line 1868, in from_pipe
    raise ValueError(
ValueError: Pipeline <class 'diffusers.pipelines.controlnet.pipeline_controlnet_inpaint_sd_xl.StableDiffusionXLControlNetInpaintPipeline'> expected {'text_encoder', 'tokenizer_2', 'unet', 'text_encoder_2', 'tokenizer', 'scheduler', 'controlnet', 'vae'}, but only {'text_encoder', 'tokenizer_2', 'unet', 'text_encoder_2', 'tokenizer', 'scheduler', 'vae'} were passed
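The first log line suggests that from_pipe compares the type of each component of the source pipeline with the type hints of the target pipeline's __init__, and that here the hint for controlnet allows only (ControlNetModel,). Below is a simplified pure-Python model of that kind of check, assuming it works this way. It is not diffusers' actual implementation. Every class and function in it (ControlNetStub, MultiControlNetStub, VaeStub, the two pipeline stubs, signature_types, from_pipe_sketch) is a hypothetical stand-in.

```python
import inspect
import warnings
from typing import Union, get_args, get_origin


# Hypothetical stand-ins for the real diffusers model classes.
class ControlNetStub:
    pass


class MultiControlNetStub:
    pass


class VaeStub:
    pass


class InpaintPipelineNarrowHint:
    # controlnet hinted as a single class, like the (ControlNetModel,) in the log
    def __init__(self, vae: VaeStub, controlnet: ControlNetStub):
        self.vae = vae
        self.controlnet = controlnet


class InpaintPipelineWideHint:
    # controlnet hinted as a Union that also admits the multi-controlnet wrapper
    def __init__(self, vae: VaeStub, controlnet: Union[ControlNetStub, MultiControlNetStub]):
        self.vae = vae
        self.controlnet = controlnet


def signature_types(cls):
    """Map each annotated __init__ parameter to a tuple of accepted classes."""
    types = {}
    for name, param in inspect.signature(cls.__init__).parameters.items():
        if param.annotation is inspect.Parameter.empty:
            continue  # skips `self` and any unannotated parameter
        if get_origin(param.annotation) is Union:
            types[name] = get_args(param.annotation)
        elif inspect.isclass(param.annotation):
            types[name] = (param.annotation,)
    return types


def from_pipe_sketch(components, new_cls, **overrides):
    """Build new_cls from components; explicit overrides bypass the type check."""
    expected = signature_types(new_cls)
    kept = dict(overrides)
    for name, component in components.items():
        if name not in expected or name in kept:
            continue
        if type(component) in expected[name]:  # exact type match, as hinted
            kept[name] = component
        else:
            warnings.warn(f"{name} is {type(component).__name__}, not carried over")
    missing = set(expected) - set(kept)
    if missing:
        raise ValueError(f"{new_cls.__name__} is missing components: {sorted(missing)}")
    return new_cls(**kept)
```

With the single-class hint, the stub controlnet is dropped and construction fails, mirroring the reported ValueError. Passing controlnet= explicitly (the workaround above) or widening the hint to a Union both succeed. If diffusers behaves like this sketch, one plausible fix would be to widen the controlnet annotation of StableDiffusionXLControlNetInpaintPipeline.__init__ so it also accepts MultiControlNetModel.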

System Info

  • 🤗 Diffusers version: 0.28.0.dev0 (b2140a8)
  • Platform: Ubuntu 22.04.4 LTS - Linux-5.15.0-1055-aws-x86_64-with-glibc2.35
  • Running on a notebook?: No
  • Running on Google Colab?: No
  • Python version: 3.10.14
  • PyTorch version (GPU?): 2.3.0+cu121 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.23.0
  • Transformers version: 4.40.1
  • Accelerate version: 0.29.3
  • PEFT version: 0.10.0
  • Bitsandbytes version: 0.43.1
  • Safetensors version: 0.4.3
  • xFormers version: not installed
  • Accelerator: NVIDIA A10G, 23028 MiB VRAM
  • Using GPU in script?: YES
  • Using distributed or parallel set-up in script?: NO

Who can help?

@yiyixuxu


Labels: bug (Something isn't working)
