
Raise Error on Mismatched Model Checkpoint and Pipeline #2799


Description

@dg845

(Follow-up to #2135)

If I load a model checkpoint with the from_pretrained(...) method into a DiffusionPipeline subclass it isn't matched to, such as using the runwayml/stable-diffusion-v1-5 text-to-image Stable Diffusion checkpoint with the StableDiffusionInpaintPipeline inpainting pipeline, the call to from_pretrained(...) succeeds, but trying to generate an image with the resulting pipeline raises an error (as expected).

As a concrete example, if I follow the StableDiffusionInpaintPipeline example with the runwayml/stable-diffusion-v1-5 checkpoint instead of the runwayml/stable-diffusion-inpainting checkpoint:

import requests
import torch
from io import BytesIO
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

def download_image(url):
    # Fetch an image over HTTP and decode it into an RGB PIL image.
    response = requests.get(url)
    return Image.open(BytesIO(response.content)).convert("RGB")

img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))

# Deliberately load a text-to-image checkpoint into the inpainting pipeline.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

pipe.to("cuda")

prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
# The mismatch only surfaces here, at inference time.
image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]

I get an error on the last line (note that this is on diffusers 0.15.0.dev0):

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>                                                                                      │
│                                                                                                  │
│ /home/tamamo/miniconda3/envs/diffusers-dev/lib/python3.10/site-packages/torch/autograd/grad_mode │
│ .py:27 in decorate_context                                                                       │
│                                                                                                  │
│    24 │   │   @functools.wraps(func)                                                             │
│    25 │   │   def decorate_context(*args, **kwargs):                                             │
│    26 │   │   │   with self.clone():                                                             │
│ ❱  27 │   │   │   │   return func(*args, **kwargs)                                               │
│    28 │   │   return cast(F, decorate_context)                                                   │
│    29 │                                                                                          │
│    30 │   def _wrap_generator(self, func):                                                       │
│                                                                                                  │
│ /home/tamamo/code/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_i │
│ npaint.py:834 in __call__                                                                        │
│                                                                                                  │
│   831 │   │   num_channels_mask = mask.shape[1]                                                  │
│   832 │   │   num_channels_masked_image = masked_image_latents.shape[1]                          │
│   833 │   │   if num_channels_latents + num_channels_mask + num_channels_masked_image != self.   │
│ ❱ 834 │   │   │   raise ValueError(                                                              │
│   835 │   │   │   │   f"Incorrect configuration settings! The config of `pipeline.unet`: {self   │
│   836 │   │   │   │   f" {self.unet.config.in_channels} but received `num_channels_latents`: {   │
│   837 │   │   │   │   f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image   │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Incorrect configuration settings! The config of `pipeline.unet`: FrozenDict([('sample_size', 64),
('in_channels', 4), ('out_channels', 4), ('center_input_sample', False), ('flip_sin_to_cos', True), ('freq_shift', 0),
('down_block_types', ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D']),
('mid_block_type', 'UNetMidBlock2DCrossAttn'), ('up_block_types', ['UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D',
'CrossAttnUpBlock2D']), ('only_cross_attention', False), ('block_out_channels', [320, 640, 1280, 1280]),
('layers_per_block', 2), ('downsample_padding', 1), ('mid_block_scale_factor', 1), ('act_fn', 'silu'), ('norm_num_groups',
32), ('norm_eps', 1e-05), ('cross_attention_dim', 768), ('attention_head_dim', 8), ('dual_cross_attention', False),
('use_linear_projection', False), ('class_embed_type', None), ('num_class_embeds', None), ('upcast_attention', False),
('resnet_time_scale_shift', 'default'), ('time_embedding_type', 'positional'), ('timestep_post_act', None),
('time_cond_proj_dim', None), ('conv_in_kernel', 3), ('conv_out_kernel', 3), ('projection_class_embeddings_input_dim',
None), ('_class_name', 'UNet2DConditionModel'), ('_diffusers_version', '0.6.0'), ('_name_or_path',
'/home/tamamo/.cache/huggingface/hub/models--runwayml--stable-diffusion-v1-5/snapshots/39593d5650112b4cc580433f6b0435385882d
819/unet')]) expects 4 but received `num_channels_latents`: 4 + `num_channels_mask`: 1 + `num_channels_masked_image`: 4 = 9.
Please verify the config of `pipeline.unet` or your `mask_image` or `image` input.
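
For context, the failure comes down to channel arithmetic: StableDiffusionInpaintPipeline concatenates 4 latent channels + 1 mask channel + 4 masked-image latent channels = 9 UNet input channels, while the text-to-image UNet's conv_in expects in_channels = 4 (the runwayml/stable-diffusion-inpainting UNet has in_channels = 9). Below is a minimal sketch of the kind of check that could run right after loading instead of at inference time; the hard-coded expected channel count and the error wording are my own illustration, not existing diffusers behavior:

import torch
from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# The inpainting pipeline feeds the UNet 4 latent channels + 1 mask
# channel + 4 masked-image latent channels = 9 input channels.
EXPECTED_IN_CHANNELS = 9  # hypothetical constant, for illustration only

if pipe.unet.config.in_channels != EXPECTED_IN_CHANNELS:
    raise ValueError(
        f"The loaded UNet takes {pipe.unet.config.in_channels} input channels, "
        f"but StableDiffusionInpaintPipeline requires {EXPECTED_IN_CHANNELS} "
        "(latents + mask + masked-image latents). This checkpoint is likely "
        "not an inpainting checkpoint."
    )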

Given this, I have the following questions:

  1. Should the DiffusionPipeline.from_pretrained(...) method catch errors where an incompatible model checkpoint is loaded into a given DiffusionPipeline subclass? (I am aware that a workaround is to always call from_pretrained(...) on the parent DiffusionPipeline class rather than on a subclass, since it will infer the correct pipeline class for the checkpoint; see this comment and the sketch after this list.)
  2. Should the error message be improved to be more user-friendly? (See e.g. this comment for more context.)
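
For reference, here is the workaround from question 1 as code. DiffusionPipeline.from_pretrained(...) reads the pipeline class name recorded in the checkpoint's model_index.json and instantiates that class, so the mismatch never arises:

import torch
from diffusers import DiffusionPipeline

# The parent class resolves the concrete pipeline class from the
# checkpoint's model_index.json instead of trusting the caller.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
print(type(pipe).__name__)  # StableDiffusionPipeline, not the inpainting pipeline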
