custom_pipeline="lpw_stable_diffusion" won't work with pipe.enable_model_cpu_offload() #3182

@cian0

Describe the bug

As the title says: everything works until custom_pipeline="lpw_stable_diffusion" (the long prompt weighting pipeline) is combined with enable_model_cpu_offload(); prompt encoding then fails with the device-mismatch RuntimeError shown in the logs below.

Reproduction

Use enable_model_cpu_offload() together with custom_pipeline="lpw_stable_diffusion":

import torch
from diffusers import AutoencoderKL, StableDiffusionPipeline

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)

pipe = StableDiffusionPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V1.4_Fantasy.ai",
    vae=vae,
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
    safety_checker=None,
)

# Offloading is what triggers the failure; without this line the pipeline runs fine.
pipe.enable_model_cpu_offload()
pipe.enable_xformers_memory_efficient_attention()

prompt = "photoshoot of a man in suit and tie looking confident at the office, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
negative_prompt = "cropped, head out of frame, (semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"

torch.manual_seed(42)  # any fixed seed; the original snippet had an undefined "seed = seed"

image = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=40, guidance_scale=7).images[0]
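
For reference, after enable_model_cpu_offload() the device the pipeline reports and the device it actually executes on diverge, which is exactly the assumption the custom pipeline trips over. A quick check (a sketch; _execution_device is diffusers' internal property that resolves the device through accelerate's offload hooks):

print(pipe.device)             # cpu    - modules are parked on the CPU until used
print(pipe._execution_device)  # cuda:0 - where the hooks actually run them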

Logs

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>:25                                                                                   │
│                                                                                                  │
│   22 seed = seed                                                                                 │
│   23 torch.manual_seed(seed)                                                                     │
│   24 # image = pipe(prompt, num_inference_steps=50, guidance_scale=7).images[0]                  │
│ ❱ 25 image = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=40, guidance_s    │
│   26                                                                                             │
│   27 # image = pipe(prompt).images[0]                                                            │
│   28                                                                                             │
│                                                                                                  │
│ /media/ian/extras/condaenvs/diffusers_mult_controlnet_pytorch13/lib/python3.9/site-packages/torc │
│ h/autograd/grad_mode.py:27 in decorate_context                                                   │
│                                                                                                  │
│    24 │   │   @functools.wraps(func)                                                             │
│    25 │   │   def decorate_context(*args, **kwargs):                                             │
│    26 │   │   │   with self.clone():                                                             │
│ ❱  27 │   │   │   │   return func(*args, **kwargs)                                               │
│    28 │   │   return cast(F, decorate_context)                                                   │
│    29 │                                                                                          │
│    30 │   def _wrap_generator(self, func):                                                       │
│                                                                                                  │
│ /media/ian/extras/huggingface_cache/modules/diffusers_modules/git/lpw_stable_diffusion.py:778 in │
│ __call__                                                                                         │
│                                                                                                  │
│    775 │   │   do_classifier_free_guidance = guidance_scale > 1.0                                │
│    776 │   │                                                                                     │
│    777 │   │   # 3. Encode input prompt                                                          │
│ ❱  778 │   │   text_embeddings = self._encode_prompt(                                            │
│    779 │   │   │   prompt,                                                                       │
│    780 │   │   │   device,                                                                       │
│    781 │   │   │   num_images_per_prompt,                                                        │
│                                                                                                  │
│ /media/ian/extras/huggingface_cache/modules/diffusers_modules/git/lpw_stable_diffusion.py:542 in │
│ _encode_prompt                                                                                   │
│                                                                                                  │
│    539 │   │   │   │   " the batch size of `prompt`."                                            │
│    540 │   │   │   )                                                                             │
│    541 │   │                                                                                     │
│ ❱  542 │   │   text_embeddings, uncond_embeddings = get_weighted_text_embeddings(                │
│    543 │   │   │   pipe=self,                                                                    │
│    544 │   │   │   prompt=prompt,                                                                │
│    545 │   │   │   uncond_prompt=negative_prompt if do_classifier_free_guidance else None,       │
│                                                                                                  │
│ /media/ian/extras/huggingface_cache/modules/diffusers_modules/git/lpw_stable_diffusion.py:366 in │
│ get_weighted_text_embeddings                                                                     │
│                                                                                                  │
│    363 │   # TODO: should we normalize by chunk or in a whole (current implementation)?          │
│    364 │   if (not skip_parsing) and (not skip_weighting):                                       │
│    365 │   │   previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.d  │
│ ❱  366 │   │   text_embeddings *= prompt_weights.unsqueeze(-1)                                   │
│    367 │   │   current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dt  │
│    368 │   │   text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)     │
│    369 │   │   if uncond_prompt is not None:                                                     │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
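
The failing line in the traceback pinpoints the mismatch: text_embeddings comes back on cuda:0 (accelerate's offload hooks run the text encoder on the GPU), while prompt_weights is built as a CPU tensor because the custom pipeline derives its device from the pipeline object, which reports "cpu" under model offload. A minimal sketch of a possible fix in lpw_stable_diffusion.py (a guess at the shape of the patch, not the actual upstream diff):

# inside get_weighted_text_embeddings (sketch): create the weight tensor on
# the embeddings' own device instead of the pipeline's reported device, so
# the in-place multiply no longer mixes cuda:0 and cpu operands.
prompt_weights = torch.tensor(
    prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device
)
text_embeddings *= prompt_weights.unsqueeze(-1)  # both tensors now on one device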

System Info

  • diffusers version: 0.16.0.dev0
  • Platform: Linux-5.19.0-40-generic-x86_64-with-glibc2.35
  • Python version: 3.9.16
  • PyTorch version (GPU?): 1.13.1 (True)
  • Huggingface_hub version: 0.13.4
  • Transformers version: 4.25.1
  • Accelerate version: 0.18.0
  • xFormers version: 0.0.18
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no
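
Until the custom pipeline handles offloaded components, one workaround is to skip enable_model_cpu_offload() and move the whole pipeline to the GPU, trading VRAM for a single-device setup; enable_attention_slicing() can recover some memory without offloading. A sketch under that assumption:

# workaround sketch: keep every component on one device instead of offloading
pipe = pipe.to("cuda")
pipe.enable_attention_slicing()  # partial memory savings without cross-device hooks
image = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=40, guidance_scale=7).images[0]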

Labels

bug, stale