KDPM2DiscreteScheduler produces latent noise in image2image when called with certain inference steps #3116

@lstein

Describe the bug

When the KDPM2DiscreteScheduler is used for image2image generation, certain combinations of strength and num_inference_steps produce latent noise instead of the expected image. The issue appears to occur when img2img_pipeline.scheduler.set_timesteps builds a timesteps vector with an odd length.

Reproduction

The following script, adapted from https://huggingface.co/docs/diffusers/using-diffusers/img2img, reproduces the error. With steps set to 30 the output is a noise image; with steps set to 31 the output is the desired image.

import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionImg2ImgPipeline, KDPM2DiscreteScheduler

# steps = 30 produces a bad (noise) image; steps = 31 gives the desired result.
# Any combination of steps and strength that results in an odd-length
# timesteps vector fails.
steps = 30
strength = 0.5

device = "cuda"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "nitrosocke/Ghibli-Diffusion",
    torch_dtype=torch.float16,
    safety_checker=None,   # am getting false positives on this image!
).to(device)
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image.thumbnail((768, 768))


k_dpm2 = KDPM2DiscreteScheduler.from_config(pipe.scheduler.config)
pipe.scheduler = k_dpm2

prompt = "ghibli style, a fantasy landscape with castles"
generator = torch.Generator(device=device).manual_seed(1024)

image = pipe(prompt=prompt,
             num_inference_steps=steps,
             strength=strength,
             image=init_image,
             guidance_scale=7.5,
             generator=generator
             ).images[0]
image.show()
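
To see which steps/strength combinations might be affected without running a full generation, the offset at which the img2img pipeline starts into the scheduler's timesteps can be computed by hand. The sketch below assumes the pipeline trims the schedule the way I understand its get_timesteps method to work in this diffusers version; the helper name img2img_start_index is hypothetical and not part of the library.

```python
def img2img_start_index(num_inference_steps: int, strength: float) -> int:
    """Hypothetical helper: index into scheduler.timesteps where img2img starts.

    Assumption: this mirrors the arithmetic of the pipeline's get_timesteps
    method; it is not a diffusers API.
    """
    # Number of denoising steps actually run for the given strength
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    # Offset into the full timesteps vector
    return max(num_inference_steps - init_timestep, 0)


if __name__ == "__main__":
    for steps in (30, 31):
        t_start = img2img_start_index(steps, 0.5)
        print(f"steps={steps} start index={t_start} steps kept={steps - t_start}")
```

Under that assumption, both cases keep 15 steps but start at different offsets (15 for steps = 30, 16 for steps = 31). The actual vector can be inspected by calling pipe.scheduler.set_timesteps(steps) and then checking len(pipe.scheduler.timesteps).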

Logs

No response

System Info

  • diffusers version: 0.15.0
  • Platform: Linux-5.15.0-69-generic-x86_64-with-glibc2.35
  • Python version: 3.10.6
  • PyTorch version (GPU?): 1.13.1+cu117 (True)
  • Huggingface_hub version: 0.13.4
  • Transformers version: 4.26.1
  • Accelerate version: 0.16.0
  • xFormers version: 0.0.16
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no
