Runtime error when trying to explicitly pass image embeddings to StableUnCLIPImg2ImgPipeline #2844

@unishift

Describe the bug

I'm trying to use StableUnCLIPImg2ImgPipeline with modified image embeddings. Passing the embeddings explicitly through image_embeds, instead of passing an image, raises a RuntimeError.

The fix is simple: in _encode_image, check whether image_embeds is None instead of casting image_embeds to bool, which fails for any tensor with more than one element. I'll create a pull request with the fix.
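To illustrate the difference between the two checks without needing a GPU or the model weights, here is a minimal sketch. MultiValueTensor, pick_embeds_buggy, pick_embeds_fixed, and encode are hypothetical names made up for this example; the stand-in class only imitates how a multi-element torch.Tensor reacts to being evaluated as a bool, and none of this is the actual diffusers code.

```python
class MultiValueTensor:
    """Hypothetical stand-in for a torch.Tensor holding several values."""

    def __init__(self, values):
        self.values = list(values)

    def __bool__(self):
        # Imitates the error from the traceback: the truth value of a
        # tensor with more than one element is ambiguous, so it raises.
        if len(self.values) > 1:
            raise RuntimeError(
                "Boolean value of Tensor with more than one value is ambiguous"
            )
        return any(self.values)


def pick_embeds_buggy(image_embeds, encode):
    # Truthiness check: evaluates the tensor as a bool and raises.
    if not image_embeds:
        image_embeds = encode()
    return image_embeds


def pick_embeds_fixed(image_embeds, encode):
    # Identity check: never evaluates the tensor as a bool.
    if image_embeds is None:
        image_embeds = encode()
    return image_embeds


def encode():
    # Placeholder for encoding an input image into embeddings.
    return "encoded-from-image"


embeds = MultiValueTensor([0.1, -0.3, 0.7])

try:
    pick_embeds_buggy(embeds, encode)
    buggy_raised = False
except RuntimeError:
    buggy_raised = True

print(buggy_raised)                                 # True
print(pick_embeds_fixed(embeds, encode) is embeds)  # True
print(pick_embeds_fixed(None, encode))              # encoded-from-image
```

With the identity check, explicitly passed embeddings are returned unchanged, and the encoder is only called when no embeddings were supplied.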

Reproduction

import torch
from diffusers import StableUnCLIPImg2ImgPipeline

pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16"
)
pipe.to("cuda")

embedding = torch.randn(1, 1024, device="cuda", dtype=torch.float16)

pipe(image_embeds=embedding)

Logs

Traceback (most recent call last):
  File "/home/experiments/remix/stable-diffusion-remix/bug.py", line 12, in <module>
    pipe(image_embeds=embedding)
  File "/home/experiments/remix/stable-diffusion-remix/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/experiments/remix/stable-diffusion-remix/.venv/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py", line 750, in __call__
    image_embeds = self._encode_image(
  File "/home/experiments/remix/stable-diffusion-remix/.venv/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py", line 391, in _encode_image
    if not image_embeds:
RuntimeError: Boolean value of Tensor with more than one value is ambiguous


System Info

- `diffusers` version: 0.15.0.dev0
- Platform: Linux-3.10.0-862.14.4.el7.x86_64-x86_64-with-glibc2.27
- Python version: 3.10.8
- PyTorch version (GPU?): 2.0.0+cu117 (True)
- Huggingface_hub version: 0.13.3
- Transformers version: 4.27.3
- Accelerate version: 0.18.0
- xFormers version: not installed
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No
