Description
(Follow up to #2135)
If I load a model checkpoint using the `from_pretrained(...)` method of a `DiffusionPipeline` subclass that doesn't match the checkpoint, such as using the `runwayml/stable-diffusion-v1-5` text-to-image Stable Diffusion checkpoint with the `StableDiffusionInpaintPipeline` inpainting pipeline, the call to `from_pretrained(...)` succeeds, but trying to generate an image from the resulting pipeline fails with an error (as expected).

As a concrete example, if I follow the `StableDiffusionInpaintPipeline` example with the `runwayml/stable-diffusion-v1-5` checkpoint instead of the `runwayml/stable-diffusion-inpainting` checkpoint:
```python
import PIL
import requests
import torch
from io import BytesIO

from diffusers import StableDiffusionInpaintPipeline


def download_image(url):
    response = requests.get(url)
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")


img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.to("cuda")

prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
```
I get an error on the last line (note that this is on `0.15.0.dev0`):
```
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>                                                                                       │
│                                                                                                   │
│ /home/tamamo/miniconda3/envs/diffusers-dev/lib/python3.10/site-packages/torch/autograd/grad_mode │
│ .py:27 in decorate_context                                                                        │
│                                                                                                   │
│   24 │   │   @functools.wraps(func)                                                               │
│   25 │   │   def decorate_context(*args, **kwargs):                                               │
│   26 │   │   │   with self.clone():                                                               │
│ ❱ 27 │   │   │   │   return func(*args, **kwargs)                                                 │
│   28 │   │   return cast(F, decorate_context)                                                     │
│   29 │                                                                                            │
│   30 │   def _wrap_generator(self, func):                                                         │
│                                                                                                   │
│ /home/tamamo/code/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_i │
│ npaint.py:834 in __call__                                                                         │
│                                                                                                   │
│   831 │   │   num_channels_mask = mask.shape[1]                                                   │
│   832 │   │   num_channels_masked_image = masked_image_latents.shape[1]                           │
│   833 │   │   if num_channels_latents + num_channels_mask + num_channels_masked_image != self.    │
│ ❱ 834 │   │   │   raise ValueError(                                                               │
│   835 │   │   │   │   f"Incorrect configuration settings! The config of `pipeline.unet`: {self    │
│   836 │   │   │   │   f" {self.unet.config.in_channels} but received `num_channels_latents`: {    │
│   837 │   │   │   │   f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image    │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Incorrect configuration settings! The config of `pipeline.unet`: FrozenDict([('sample_size', 64),
('in_channels', 4), ('out_channels', 4), ('center_input_sample', False), ('flip_sin_to_cos', True), ('freq_shift', 0),
('down_block_types', ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D']),
('mid_block_type', 'UNetMidBlock2DCrossAttn'), ('up_block_types', ['UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D',
'CrossAttnUpBlock2D']), ('only_cross_attention', False), ('block_out_channels', [320, 640, 1280, 1280]),
('layers_per_block', 2), ('downsample_padding', 1), ('mid_block_scale_factor', 1), ('act_fn', 'silu'), ('norm_num_groups',
32), ('norm_eps', 1e-05), ('cross_attention_dim', 768), ('attention_head_dim', 8), ('dual_cross_attention', False),
('use_linear_projection', False), ('class_embed_type', None), ('num_class_embeds', None), ('upcast_attention', False),
('resnet_time_scale_shift', 'default'), ('time_embedding_type', 'positional'), ('timestep_post_act', None),
('time_cond_proj_dim', None), ('conv_in_kernel', 3), ('conv_out_kernel', 3), ('projection_class_embeddings_input_dim',
None), ('_class_name', 'UNet2DConditionModel'), ('_diffusers_version', '0.6.0'), ('_name_or_path',
'/home/tamamo/.cache/huggingface/hub/models--runwayml--stable-diffusion-v1-5/snapshots/39593d5650112b4cc580433f6b0435385882d
819/unet')]) expects 4 but received `num_channels_latents`: 4 + `num_channels_mask`: 1 + `num_channels_masked_image`: 4 = 9.
Please verify the config of `pipeline.unet` or your `mask_image` or `image` input.
```
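For context on why this only surfaces at call time: the inpainting pipeline concatenates the latents, the mask, and the masked-image latents inside `__call__`, so the 9-channel input only meets the 4-channel text-to-image UNet there. The mismatch is visible directly from the loaded UNet's config (a minimal check, reusing the `pipe` from the snippet above):

```python
# The text-to-image checkpoint's UNet accepts 4 input channels, while an
# inpainting UNet (e.g. runwayml/stable-diffusion-inpainting) accepts 9:
# 4 latent + 1 mask + 4 masked-image latent channels.
print(pipe.unet.config.in_channels)  # -> 4 for runwayml/stable-diffusion-v1-5
```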
Given this, I have the following questions:

- Should the `DiffusionPipeline.from_pretrained(...)` method catch errors where we load a model checkpoint that is incompatible with a given `DiffusionPipeline` subclass? (I am aware that a workaround is to always use the parent `DiffusionPipeline` class rather than a child class when calling `from_pretrained(...)`, since it will properly infer the correct pipeline class for the checkpoint; see this comment. A sketch of the workaround is included after this list.)
- Should the error message be improved to be more user-friendly? (See e.g. this comment for more context.)
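For reference, a minimal sketch of the workaround, plus what an up-front check could look like (the parent-class inference is documented behavior; the `_class_name` comparison via `DiffusionPipeline.load_config(...)` is just one possible shape for the validation asked about in the first question):

```python
import torch
from diffusers import DiffusionPipeline

# Workaround: the parent class reads the checkpoint's model_index.json and
# instantiates the pipeline class recorded there.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
print(type(pipe).__name__)  # -> StableDiffusionPipeline, not the inpainting class

# One possible shape for the check in question 1: compare the requested
# subclass against the class the checkpoint itself declares.
config = DiffusionPipeline.load_config("runwayml/stable-diffusion-v1-5")
if config["_class_name"] != "StableDiffusionInpaintPipeline":
    print(f"Checkpoint declares {config['_class_name']}, not an inpainting pipeline.")
```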