
DPMSolverMultistepScheduler with AutoPipelineForImage2Image fails at specific combinations of step counts and strength #9366

@frankjoshua

Description


Describe the bug

When I use DPMSolverMultistepScheduler with certain combinations of step count and prompt strength, the pipeline crashes with an IndexError. The issue has been reproduced with multiple models.

Our use case is image refining, so the prompt strength is low. If we could figure out the pattern and filter out the failing inputs, that would also be helpful.

Some examples:
num_inference_steps = 100, strength = 0.1 -> CRASH
num_inference_steps = 120, strength = 0.1 -> PASS
num_inference_steps = 90, strength = 0.1 -> PASS
num_inference_steps = 50, strength = 0.05 -> PASS
num_inference_steps = 20, strength = 0.5 -> PASS
num_inference_steps = 200, strength = 0.05 -> CRASH
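To look for the pattern, it may help to compute how many denoising steps each example actually runs. The sketch below assumes the img2img pipeline trims the schedule the way we read its get_timesteps logic: init_timestep = min(int(num_inference_steps * strength), num_inference_steps), then t_start = num_inference_steps - init_timestep, for a scheduler of order 1. Treat that formula as an assumption, not a quote of the library. img2img_step_plan is a hypothetical helper.

```python
# Hypothetical helper: approximates how an img2img pipeline trims the schedule.
# Assumption: mirrors our reading of diffusers' get_timesteps for order-1 schedulers.
def img2img_step_plan(num_inference_steps: int, strength: float) -> tuple[int, int]:
    """Return (t_start, steps_run): the schedule index where denoising starts
    and the number of denoising steps actually executed."""
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return t_start, num_inference_steps - t_start


# The (steps, strength, outcome) combinations reported above.
examples = [
    (100, 0.1, "CRASH"),
    (120, 0.1, "PASS"),
    (90, 0.1, "PASS"),
    (50, 0.05, "PASS"),
    (20, 0.5, "PASS"),
    (200, 0.05, "CRASH"),
]

for steps, strength, outcome in examples:
    t_start, steps_run = img2img_step_plan(steps, strength)
    print(f"{steps:>4} steps, strength {strength:<4} -> start index {t_start:>3}, "
          f"{steps_run:>2} steps run: {outcome}")
```

Under this assumption, 100 steps at strength 0.1 runs 10 steps starting at index 90, which agrees with the 9/10 progress bar in the logs. Both crashing cases run 10 steps starting near the end of the schedule, but 20 steps at strength 0.5 also runs 10 steps and passes, so the number of steps run does not explain the failures on its own.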

Traceback (most recent call last):
  File "/root/app/sd_playground/sd_img_2_img.py", line 24, in <module>
    image = pipeline(prompt, num_inference_steps=100, image=init_image, strength=0.1).images[0]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py", line 1412, in __call__
    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/diffusers/schedulers/scheduling_dpmsolver_multistep.py", line 991, in step
    prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/diffusers/schedulers/scheduling_dpmsolver_multistep.py", line 724, in multistep_dpm_solver_second_order_update
    self.sigmas[self.step_index + 1],
    ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^
IndexError: index 101 is out of bounds for dimension 0 with size 101
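The failing index can be checked with a little arithmetic. Assuming the scheduler keeps one sigma per timestep plus a final entry (101 sigmas for 100 steps, which matches the error message), and assuming the img2img loop starts the scheduler at index 90 and advances step_index by one per step for 100 steps at strength 0.1, the sketch below lists the sigmas[step_index + 1] reads a 10-step run should make. expected_sigma_reads is a hypothetical helper, not part of diffusers.

```python
# Hypothetical helper (not part of diffusers): lists the sigmas[step_index + 1]
# reads a run should make, assuming step_index starts at begin_index and
# advances by one per denoising step.
def expected_sigma_reads(num_inference_steps: int, begin_index: int, steps_run: int) -> list[int]:
    num_sigmas = num_inference_steps + 1  # assumption: one sigma per timestep plus a final one
    reads = [step_index + 1 for step_index in range(begin_index, begin_index + steps_run)]
    if reads and reads[-1] >= num_sigmas:
        raise ValueError(f"read {reads[-1]} would overrun {num_sigmas} sigmas")
    return reads


# 100 steps at strength 0.1: assumed to start at index 90 and run 10 steps.
reads = expected_sigma_reads(num_inference_steps=100, begin_index=90, steps_run=10)
print(reads[-1])  # 100, the last valid index for a tensor of size 101

# The traceback instead reads index 101 during the 10th step (the log stops at 9/10),
# i.e. step_index == 100, one past the expected final step_index of 99.
```

Under these assumptions no in-bounds run reads past index 100, so the crash suggests the scheduler's step_index ended up one position ahead of the truncated schedule. This is a hypothesis to verify against the scheduler code, not a confirmed cause.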

Reproduction

import torch
from diffusers import AutoPipelineForImage2Image, DPMSolverMultistepScheduler
from diffusers.utils import load_image

pipeline = AutoPipelineForImage2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
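pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    pipeline.scheduler.config,
    use_karras_sigmas=True,
    # sde_type is not a DPMSolverMultistepScheduler option; the SDE variant
    # is normally selected with algorithm_type="sde-dpmsolver++"
    sde_type="sde-dpmsolver++",
    euler_at_final=True,
    use_lu_lambdas=True,
)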
pipeline.enable_model_cpu_offload()

# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png"
init_image = load_image(url)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

# pass prompt and image to pipeline
image = pipeline(prompt, num_inference_steps=100, image=init_image, strength=0.1).images[0]
image.save("sd_i2i.png")
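Until the cause is fixed, one way to act on the "filter the input" idea is to catch the IndexError and retry with a nearby step count, since in the examples above nearby counts (90 and 120 at strength 0.1) pass. This sketch assumes the failure is deterministic for a given (num_inference_steps, strength) pair. generate_with_step_fallback and its offsets are hypothetical; run stands in for a call such as lambda steps: pipeline(prompt, num_inference_steps=steps, image=init_image, strength=0.1).

```python
from typing import Callable, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def generate_with_step_fallback(
    run: Callable[[int], T],
    num_inference_steps: int,
    offsets: Sequence[int] = (0, -1, 1, -2, 2),
) -> Tuple[int, T]:
    """Call run(steps) with the requested step count, retrying nearby counts
    when an IndexError is raised. Returns (steps_used, result)."""
    last_error: Optional[IndexError] = None
    for offset in offsets:
        steps = num_inference_steps + offset
        if steps < 1:
            continue  # skip nonsensical step counts
        try:
            return steps, run(steps)
        except IndexError as err:  # the failure mode seen in the traceback
            last_error = err
    raise RuntimeError(f"all step counts near {num_inference_steps} failed") from last_error


# Stand-in for the pipeline call: fails at 100 steps, like the report.
def fake_run(steps: int) -> str:
    if steps == 100:
        raise IndexError("index 101 is out of bounds for dimension 0 with size 101")
    return f"image@{steps}"


print(generate_with_step_fallback(fake_run, 100))  # (99, 'image@99')
```

Retrying changes the effective number of denoising steps, so the result can differ slightly from what the requested settings would have produced.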

Logs

root@josh-office:~/app/sd_playground# python sd_img_2_img.py 
2024-09-04 23:56:46.401813: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-04 23:56:46.413201: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-04 23:56:46.416746: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-04 23:56:46.426485: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-09-04 23:56:47.100092: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 11.11it/s]
 90%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌             | 9/10 [00:03<00:00,  2.26it/s]
Traceback (most recent call last):
  File "/root/app/sd_playground/sd_img_2_img.py", line 24, in <module>
    image = pipeline(prompt, num_inference_steps=100, image=init_image, strength=0.1).images[0]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py", line 1412, in __call__
    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/diffusers/schedulers/scheduling_dpmsolver_multistep.py", line 991, in step
    prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/diffusers/schedulers/scheduling_dpmsolver_multistep.py", line 724, in multistep_dpm_solver_second_order_update
    self.sigmas[self.step_index + 1],
    ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^
IndexError: index 101 is out of bounds for dimension 0 with size 101

System Info

  • 🤗 Diffusers version: 0.30.1
  • Platform: Linux-6.8.0-40-generic-x86_64-with-glibc2.39
  • Running on Google Colab?: No
  • Python version: 3.12.3
  • PyTorch version (GPU?): 2.4.0+cu121 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.24.6
  • Transformers version: 4.44.0
  • Accelerate version: 0.33.0
  • PEFT version: 0.12.0
  • Bitsandbytes version: not installed
  • Safetensors version: 0.4.4
  • xFormers version: not installed
  • Accelerator: NVIDIA GeForce GTX 1660, 6144 MiB
    NVIDIA GeForce RTX 3090, 24576 MiB
  • Using GPU in script?: Nvidia RTX 3090
  • Using distributed or parallel set-up in script?: No

Who can help?

No response

Labels

bug (Something isn't working)
