7 changes: 6 additions & 1 deletion scripts/convert_stable_diffusion_checkpoint_to_onnx.py
@@ -99,7 +99,12 @@ def convert_models(model_path: str, output_path: str, opset: int):
     unet_path = output_path / "unet" / "model.onnx"
     onnx_export(
         pipeline.unet,
-        model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False),
+        model_args=(
+            torch.randn(2, pipeline.unet.in_channels, 64, 64),
+            torch.LongTensor([0, 1]),
+            torch.randn(2, 77, 768),
+            False,
+        ),
         output_path=unet_path,
         ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
         output_names=["out_sample"],  # has to be different from "sample" for correct tracing
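The hardcoded channel count mattered here: the inpainting UNet consumes latents, mask, and masked-image latents concatenated along the channel axis, so its `in_channels` is 9 rather than the text-to-image model's 4, and a dummy input of shape `(2, 4, 64, 64)` would trace the wrong graph during export. A minimal sketch of the difference; the model ID and shapes are illustrative assumptions, not part of this PR:

```python
# Sketch only: why the export must read in_channels from the model instead of
# hardcoding 4.
import torch
from diffusers import StableDiffusionInpaintPipeline

pipeline = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
print(pipeline.unet.in_channels)  # 9: 4 latent + 1 mask + 4 masked-image latent channels

# The dummy sample traced by onnx_export must match that channel count:
sample = torch.randn(2, pipeline.unet.in_channels, 64, 64)
```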
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -5,7 +5,6 @@
 import torch
 
 import PIL
-from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -16,28 +15,29 @@
 from . import StableDiffusionPipelineOutput
 
 
-logger = logging.get_logger(__name__)
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-def preprocess_image(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
-    image = np.array(image).astype(np.float32) / 255.0
+NUM_UNET_INPUT_CHANNELS = 9
+NUM_LATENT_CHANNELS = 4
+
+
+def prepare_mask_and_masked_image(image, mask, latents_shape):
+    image = np.array(image.convert("RGB"))
     image = image[None].transpose(0, 3, 1, 2)
-    return 2.0 * image - 1.0
+    image = image.astype(np.float32) / 127.5 - 1.0
+
+    image_mask = np.array(mask.convert("L"))
+    masked_image = image * (image_mask < 127.5)
 
-
-def preprocess_mask(mask):
-    mask = mask.convert("L")
-    w, h = mask.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
-    mask = np.array(mask).astype(np.float32) / 255.0
-    mask = np.tile(mask, (4, 1, 1))
-    mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
-    mask = 1 - mask  # repaint white, keep black
-    return mask
+    mask = mask.resize((latents_shape[1], latents_shape[0]), PIL.Image.NEAREST)
+    mask = np.array(mask.convert("L"))
+    mask = mask.astype(np.float32) / 255.0
+    mask = mask[None, None]
+    mask[mask < 0.5] = 0
+    mask[mask >= 0.5] = 1
+
+    return mask, masked_image
 
 
 class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
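For reference, the new helper binarizes the mask at latent resolution and zeroes out the to-be-repainted pixels of the image. A quick sketch of the shapes it produces, assuming the function above is in scope; the inputs are illustrative, not from the PR:

```python
import PIL.Image

# A gray 512x512 image with an all-white mask (repaint everything).
image = PIL.Image.new("RGB", (512, 512), "gray")
mask = PIL.Image.new("L", (512, 512), 255)

# latents_shape[-2:] for 512x512 is (64, 64), i.e. (height // 8, width // 8).
mask_arr, masked_image = prepare_mask_and_masked_image(image, mask, (64, 64))
print(mask_arr.shape)      # (1, 1, 64, 64), binary 0/1 at latent resolution
print(masked_image.shape)  # (1, 3, 512, 512), in [-1, 1], white-masked pixels zeroed
```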
@@ -129,14 +129,16 @@ def __init__(
     def __call__(
         self,
         prompt: Union[str, List[str]],
-        init_image: Union[np.ndarray, PIL.Image.Image],
-        mask_image: Union[np.ndarray, PIL.Image.Image],
-        strength: float = 0.8,
-        num_inference_steps: Optional[int] = 50,
-        guidance_scale: Optional[float] = 7.5,
+        image: PIL.Image.Image,
+        mask_image: PIL.Image.Image,
+        height: int = 512,
+        width: int = 512,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
-        eta: Optional[float] = 0.0,
+        eta: float = 0.0,
+        latents: Optional[np.ndarray] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
@@ -149,22 +151,21 @@ def __call__(
         Args:
             prompt (`str` or `List[str]`):
                 The prompt or prompts to guide the image generation.
-            init_image (`np.ndarray` or `PIL.Image.Image`):
-                `Image`, or tensor representing an image batch, that will be used as the starting point for the
-                process. This is the image whose masked region will be inpainted.
-            mask_image (`np.ndarray` or `PIL.Image.Image`):
-                `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
-                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
-                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
-                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
-            strength (`float`, *optional*, defaults to 0.8):
-                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
-                is 1, the denoising process will be run on the masked area for the full number of iterations specified
-                in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
-                noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+            image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+                be masked out with `mask_image` and repainted according to `prompt`.
+            mask_image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+                to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+                instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            height (`int`, *optional*, defaults to 512):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to 512):
+                The width in pixels of the generated image.
             num_inference_steps (`int`, *optional*, defaults to 50):
-                The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
-                the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
             guidance_scale (`float`, *optional*, defaults to 7.5):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -179,6 +180,10 @@ def __call__(
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
+            latents (`np.ndarray`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
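Since `latents` is new to this pipeline's signature, here is a hedged sketch of the reproducibility use-case it enables; `pipe`, `init_image`, and `mask_image` are assumed to be set up as in the usage sketch at the end of this file's diff, and the prompts are placeholders:

```python
import numpy as np

# Shape follows the pipeline's internal (batch, 4, height // 8, width // 8) layout.
latents = np.random.RandomState(0).randn(1, 4, 64, 64).astype(np.float32)
out_a = pipe(prompt="a red brick wall", image=init_image, mask_image=mask_image, latents=latents)
out_b = pipe(prompt="an ivy-covered wall", image=init_image, mask_image=mask_image, latents=latents)
# out_a and out_b share the same starting noise, so only the prompt differs.
```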
@@ -206,8 +211,8 @@ def __call__(
         else:
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
-        if strength < 0 or strength > 1:
-            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
         if (callback_steps is None) or (
             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
@@ -285,41 +290,46 @@ def __call__(
         # to avoid doing two forward passes
         text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
 
-        # preprocess image
-        if not isinstance(init_image, torch.FloatTensor):
-            init_image = preprocess_image(init_image)
+        num_channels_latents = NUM_LATENT_CHANNELS
+        latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
+        latents_dtype = text_embeddings.dtype
+        if latents is None:
+            latents = np.random.randn(*latents_shape).astype(latents_dtype)
+        else:
+            if latents.shape != latents_shape:
+                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
 
-        # encode the init image into latents and scale the latents
-        init_latents = self.vae_encoder(sample=init_image)[0]
-        init_latents = 0.18215 * init_latents
+        # prepare mask and masked_image
+        mask, masked_image = prepare_mask_and_masked_image(image, mask_image, latents_shape[-2:])
+        mask = mask.astype(latents.dtype)
+        masked_image = masked_image.astype(latents.dtype)
 
-        # Expand init_latents for batch_size and num_images_per_prompt
-        init_latents = np.concatenate([init_latents] * batch_size * num_images_per_prompt, axis=0)
-        init_latents_orig = init_latents
+        masked_image_latents = self.vae_encoder(sample=masked_image)[0]
+        masked_image_latents = 0.18215 * masked_image_latents
 
-        # preprocess mask
-        if not isinstance(mask_image, np.ndarray):
-            mask_image = preprocess_mask(mask_image)
-        mask = np.concatenate([mask_image] * batch_size * num_images_per_prompt)
+        mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask
+        masked_image_latents = (
+            np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+        )
 
-        # check sizes
-        if not mask.shape == init_latents.shape:
-            raise ValueError("The mask and init_image should be the same size!")
+        num_channels_mask = mask.shape[1]
+        num_channels_masked_image = masked_image_latents.shape[1]
 
-        # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
+        unet_input_channels = NUM_UNET_INPUT_CHANNELS
+        if num_channels_latents + num_channels_mask + num_channels_masked_image != unet_input_channels:
+            raise ValueError(
+                "Incorrect configuration settings! The config of `pipeline.unet` expects"
+                f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +"
+                f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+                f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+                " `pipeline.unet` or your `mask_image` or `image` input."
+            )
 
-        timesteps = self.scheduler.timesteps.numpy()[-init_timestep]
-        timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)
+        # set timesteps
+        self.scheduler.set_timesteps(num_inference_steps)
 
-        # add noise to latents using the timesteps
-        noise = np.random.randn(*init_latents.shape).astype(np.float32)
-        init_latents = self.scheduler.add_noise(
-            torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
-        )
-        init_latents = init_latents.numpy()
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
 
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
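The channel bookkeeping above is easy to verify by hand: the UNet sample is the latents, the binary mask, and the masked-image latents concatenated on the channel axis, and classifier-free guidance doubles the batch. A self-contained sketch with illustrative shapes for one 512x512 image:

```python
import numpy as np

latents = np.zeros((1, 4, 64, 64))               # NUM_LATENT_CHANNELS = 4
mask = np.zeros((2, 1, 64, 64))                  # 1 channel, batch doubled for CFG
masked_image_latents = np.zeros((2, 4, 64, 64))  # 4 channels, batch doubled for CFG

latent_model_input = np.concatenate([latents] * 2)  # CFG doubles the latents too
sample = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
assert sample.shape == (2, 9, 64, 64)  # 4 + 1 + 4 = NUM_UNET_INPUT_CHANNELS
```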
@@ -330,15 +340,13 @@ def __call__(
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
-        latents = init_latents
-
-        t_start = max(num_inference_steps - init_timestep + offset, 0)
-        timesteps = self.scheduler.timesteps[t_start:].numpy()
-
-        for i, t in tqdm(enumerate(timesteps)):
+        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+            # concat latents, mask, masked_image_latents in the channel dimension
+            latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
+            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+            latent_model_input = latent_model_input.numpy()
 
             # predict the noise residual
             noise_pred = self.unet(
@@ -353,12 +361,6 @@ def __call__(
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
             latents = latents.numpy()
-            # masking
-            init_latents_proper = self.scheduler.add_noise(
-                torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.tensor([t])
-            )
-
-            latents = (init_latents_proper * mask) + (latents * (1 - mask))
 
             # call the callback, if provided
             if callback is not None and i % callback_steps == 0:
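Taken together, the pipeline's public API changes from `init_image`/`strength` to `image`/`mask_image` (plus `height`/`width`). A rough usage sketch of the updated pipeline; the file names and prompt are placeholders:

```python
import PIL.Image
from diffusers import OnnxStableDiffusionInpaintPipeline

pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", revision="onnx", provider="CPUExecutionProvider"
)

# Placeholder inputs: any RGB image plus a white-on-black mask of the same size.
init_image = PIL.Image.open("bench.png").convert("RGB").resize((512, 512))
mask_image = PIL.Image.open("bench_mask.png").convert("L").resize((512, 512))

# `strength` is gone; the mask alone determines the repainted region.
result = pipe(prompt="a red cat sitting on a bench", image=init_image, mask_image=mask_image)
result.images[0].save("inpainted.png")
```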
15 changes: 15 additions & 0 deletions src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -49,6 +49,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class StableDiffusionInpaintPipelineLegacy(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class StableDiffusionPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
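For context, these dummy classes follow diffusers' usual pattern: when `torch` or `transformers` are not installed, the import machinery exposes this placeholder instead of the real `StableDiffusionInpaintPipelineLegacy`, and instantiating it raises an `ImportError` naming the missing backends. A rough sketch, with the error text paraphrased rather than quoted from the library:

```python
# Assumes an environment without torch/transformers, so the dummy is exported.
from diffusers.utils.dummy_torch_and_transformers_objects import (
    StableDiffusionInpaintPipelineLegacy,
)

try:
    StableDiffusionInpaintPipelineLegacy()
except ImportError as err:
    # Something like: "StableDiffusionInpaintPipelineLegacy requires the torch
    # and transformers libraries" (paraphrased).
    print(err)
```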
7 changes: 3 additions & 4 deletions tests/test_pipelines.py
@@ -2271,7 +2271,7 @@ def test_stable_diffusion_inpaint_onnx(self):
         )
 
         pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
+            "runwayml/stable-diffusion-inpainting", revision="onnx", provider="CPUExecutionProvider"
         )
         pipe.set_progress_bar_config(disable=None)
 
@@ -2280,9 +2280,8 @@ def test_stable_diffusion_inpaint_onnx(self):
         np.random.seed(0)
         output = pipe(
             prompt=prompt,
-            init_image=init_image,
+            image=init_image,
             mask_image=mask_image,
-            strength=0.75,
             guidance_scale=7.5,
             num_inference_steps=8,
             output_type="np",
@@ -2291,7 +2290,7 @@ def test_stable_diffusion_inpaint_onnx(self):
         image_slice = images[0, 255:258, 255:258, -1]
 
         assert images.shape == (1, 512, 512, 3)
-        expected_slice = np.array([0.3524, 0.3289, 0.3464, 0.3872, 0.4129, 0.3566, 0.3709, 0.4128, 0.3734])
+        expected_slice = np.array([0.2951, 0.2955, 0.2922, 0.2036, 0.1977, 0.2279, 0.1716, 0.1641, 0.1799])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
 
     @slow