
Commit 8556a84

anton-l and kumquatexpress authored and committed

ONNX supervised inpainting (huggingface#906)

* ONNX supervised inpainting
* sync with the torch pipeline
* fix concat
* update ref values
* back to 8 steps
* type fix
* make fix-copies

1 parent 2d9713b commit 8556a84
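For context, a minimal usage sketch of the pipeline this commit adds, mirroring the updated test further down (the prompt and image files are illustrative, not part of the commit):

```python
import PIL.Image

from diffusers import OnnxStableDiffusionInpaintPipeline

# Illustrative 512x512 inputs: white pixels in the mask mark the region to repaint.
init_image = PIL.Image.open("bench.png").convert("RGB").resize((512, 512))
mask_image = PIL.Image.open("bench_mask.png").convert("L").resize((512, 512))

pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", revision="onnx", provider="CPUExecutionProvider"
)
result = pipe(
    prompt="a red fox sitting on a bench",
    image=init_image,
    mask_image=mask_image,
    guidance_scale=7.5,
    num_inference_steps=50,
)
result.images[0].save("inpainted.png")
```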

File tree: 4 files changed (+108, -87 lines)


scripts/convert_stable_diffusion_checkpoint_to_onnx.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -99,7 +99,12 @@ def convert_models(model_path: str, output_path: str, opset: int):
     unet_path = output_path / "unet" / "model.onnx"
     onnx_export(
         pipeline.unet,
-        model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False),
+        model_args=(
+            torch.randn(2, pipeline.unet.in_channels, 64, 64),
+            torch.LongTensor([0, 1]),
+            torch.randn(2, 77, 768),
+            False,
+        ),
         output_path=unet_path,
         ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
         output_names=["out_sample"],  # has to be different from "sample" for correct tracing
```
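The dummy `sample` no longer hard-codes 4 channels, because the supervised inpainting UNet consumes 9. A sketch of the exported inputs with the shapes spelled out (the literal `in_channels` value below stands in for `pipeline.unet.in_channels` on an inpainting checkpoint):

```python
import torch

# pipeline.unet.in_channels is 4 for the text-to-image UNet and 9 for the
# inpainting UNet (4 latent + 1 mask + 4 masked-image latent channels), so
# reading it from the model config keeps the export script generic.
in_channels = 9  # assumed value for an inpainting checkpoint
sample = torch.randn(2, in_channels, 64, 64)     # latents: 512 px / 8 = 64
timestep = torch.LongTensor([0, 1])              # one timestep per batch item
encoder_hidden_states = torch.randn(2, 77, 768)  # CLIP text embeddings: 77 tokens x 768 dims
```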

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py

Lines changed: 84 additions & 82 deletions
```diff
@@ -5,7 +5,6 @@
 import torch

 import PIL
-from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTokenizer

 from ...configuration_utils import FrozenDict
```
```diff
@@ -16,28 +15,29 @@
 from . import StableDiffusionPipelineOutput


-logger = logging.get_logger(__name__)
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-def preprocess_image(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
-    image = np.array(image).astype(np.float32) / 255.0
+NUM_UNET_INPUT_CHANNELS = 9
+NUM_LATENT_CHANNELS = 4
+
+
+def prepare_mask_and_masked_image(image, mask, latents_shape):
+    image = np.array(image.convert("RGB"))
     image = image[None].transpose(0, 3, 1, 2)
-    return 2.0 * image - 1.0
+    image = image.astype(np.float32) / 127.5 - 1.0

+    image_mask = np.array(mask.convert("L"))
+    masked_image = image * (image_mask < 127.5)

-def preprocess_mask(mask):
-    mask = mask.convert("L")
-    w, h = mask.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
-    mask = np.array(mask).astype(np.float32) / 255.0
-    mask = np.tile(mask, (4, 1, 1))
-    mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
-    mask = 1 - mask  # repaint white, keep black
-    return mask
+    mask = mask.resize((latents_shape[1], latents_shape[0]), PIL.Image.NEAREST)
+    mask = np.array(mask.convert("L"))
+    mask = mask.astype(np.float32) / 255.0
+    mask = mask[None, None]
+    mask[mask < 0.5] = 0
+    mask[mask >= 0.5] = 1
+
+    return mask, masked_image


 class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
```
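A quick sketch of the new helper's contract, assuming a 512x512 input and hence a 64x64 latent grid (the deep import path simply follows the file being edited here):

```python
import PIL.Image

from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion_inpaint import (
    prepare_mask_and_masked_image,
)

# Illustrative inputs: a solid gray image and a mask whose white right half
# marks the region to repaint.
image = PIL.Image.new("RGB", (512, 512), "gray")
mask = PIL.Image.new("L", (512, 512), 0)
mask.paste(255, (256, 0, 512, 512))

mask_arr, masked_image = prepare_mask_and_masked_image(image, mask, (64, 64))
print(mask_arr.shape)      # (1, 1, 64, 64), binarized to {0, 1}
print(masked_image.shape)  # (1, 3, 512, 512), white-masked pixels zeroed out
```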
```diff
@@ -129,14 +129,16 @@ def __init__(
     def __call__(
         self,
         prompt: Union[str, List[str]],
-        init_image: Union[np.ndarray, PIL.Image.Image],
-        mask_image: Union[np.ndarray, PIL.Image.Image],
-        strength: float = 0.8,
-        num_inference_steps: Optional[int] = 50,
-        guidance_scale: Optional[float] = 7.5,
+        image: PIL.Image.Image,
+        mask_image: PIL.Image.Image,
+        height: int = 512,
+        width: int = 512,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
-        eta: Optional[float] = 0.0,
+        eta: float = 0.0,
+        latents: Optional[np.ndarray] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
```
```diff
@@ -149,22 +151,21 @@ def __call__(
         Args:
             prompt (`str` or `List[str]`):
                 The prompt or prompts to guide the image generation.
-            init_image (`np.ndarray` or `PIL.Image.Image`):
-                `Image`, or tensor representing an image batch, that will be used as the starting point for the
-                process. This is the image whose masked region will be inpainted.
-            mask_image (`np.ndarray` or `PIL.Image.Image`):
-                `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
-                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
-                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
-                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
-            strength (`float`, *optional*, defaults to 0.8):
-                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
-                is 1, the denoising process will be run on the masked area for the full number of iterations specified
-                in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
-                noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+            image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+                be masked out with `mask_image` and repainted according to `prompt`.
+            mask_image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+                to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+                instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            height (`int`, *optional*, defaults to 512):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to 512):
+                The width in pixels of the generated image.
             num_inference_steps (`int`, *optional*, defaults to 50):
-                The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
-                the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
             guidance_scale (`float`, *optional*, defaults to 7.5):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
```
```diff
@@ -179,6 +180,10 @@ def __call__(
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
+            latents (`np.ndarray`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
```
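The new `latents` argument makes runs reproducible without touching numpy's global seed; a sketch, assuming a single 512x512 image (the shape must match `(batch_size * num_images_per_prompt, 4, height // 8, width // 8)`, as validated below):

```python
import numpy as np

# Pre-generated latents for a reproducible 512x512 run.
rng = np.random.RandomState(0)
latents = rng.randn(1, 4, 64, 64).astype(np.float32)
# result = pipe(prompt=..., image=..., mask_image=..., latents=latents)
```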
```diff
@@ -206,8 +211,8 @@ def __call__(
         else:
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

-        if strength < 0 or strength > 1:
-            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

         if (callback_steps is None) or (
             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
```
```diff
@@ -285,41 +290,46 @@ def __call__(
         # to avoid doing two forward passes
         text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])

-        # preprocess image
-        if not isinstance(init_image, torch.FloatTensor):
-            init_image = preprocess_image(init_image)
+        num_channels_latents = NUM_LATENT_CHANNELS
+        latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
+        latents_dtype = text_embeddings.dtype
+        if latents is None:
+            latents = np.random.randn(*latents_shape).astype(latents_dtype)
+        else:
+            if latents.shape != latents_shape:
+                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

-        # encode the init image into latents and scale the latents
-        init_latents = self.vae_encoder(sample=init_image)[0]
-        init_latents = 0.18215 * init_latents
+        # prepare mask and masked_image
+        mask, masked_image = prepare_mask_and_masked_image(image, mask_image, latents_shape[-2:])
+        mask = mask.astype(latents.dtype)
+        masked_image = masked_image.astype(latents.dtype)

-        # Expand init_latents for batch_size and num_images_per_prompt
-        init_latents = np.concatenate([init_latents] * batch_size * num_images_per_prompt, axis=0)
-        init_latents_orig = init_latents
+        masked_image_latents = self.vae_encoder(sample=masked_image)[0]
+        masked_image_latents = 0.18215 * masked_image_latents

-        # preprocess mask
-        if not isinstance(mask_image, np.ndarray):
-            mask_image = preprocess_mask(mask_image)
-        mask = np.concatenate([mask_image] * batch_size * num_images_per_prompt)
+        mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask
+        masked_image_latents = (
+            np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+        )

-        # check sizes
-        if not mask.shape == init_latents.shape:
-            raise ValueError("The mask and init_image should be the same size!")
+        num_channels_mask = mask.shape[1]
+        num_channels_masked_image = masked_image_latents.shape[1]

-        # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
+        unet_input_channels = NUM_UNET_INPUT_CHANNELS
+        if num_channels_latents + num_channels_mask + num_channels_masked_image != unet_input_channels:
+            raise ValueError(
+                "Incorrect configuration settings! The config of `pipeline.unet` expects"
+                f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +"
+                f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+                f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+                " `pipeline.unet` or your `mask_image` or `image` input."
+            )

-        timesteps = self.scheduler.timesteps.numpy()[-init_timestep]
-        timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)
+        # set timesteps
+        self.scheduler.set_timesteps(num_inference_steps)

-        # add noise to latents using the timesteps
-        noise = np.random.randn(*init_latents.shape).astype(np.float32)
-        init_latents = self.scheduler.add_noise(
-            torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
-        )
-        init_latents = init_latents.numpy()
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
```
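The 9-channel requirement enforced above is just the sum of the three concatenated inputs; a worked sketch for batch size 1 at 512x512:

```python
import numpy as np

latents = np.zeros((1, 4, 64, 64), dtype=np.float32)               # NUM_LATENT_CHANNELS = 4
mask = np.zeros((1, 1, 64, 64), dtype=np.float32)                  # binarized mask, 1 channel
masked_image_latents = np.zeros((1, 4, 64, 64), dtype=np.float32)  # VAE-encoded masked image, 4 channels

unet_input = np.concatenate([latents, mask, masked_image_latents], axis=1)
assert unet_input.shape == (1, 9, 64, 64)  # 4 + 1 + 4 = NUM_UNET_INPUT_CHANNELS
```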
```diff
@@ -330,15 +340,13 @@ def __call__(
         if accepts_eta:
             extra_step_kwargs["eta"] = eta

-        latents = init_latents
-
-        t_start = max(num_inference_steps - init_timestep + offset, 0)
-        timesteps = self.scheduler.timesteps[t_start:].numpy()
-
-        for i, t in tqdm(enumerate(timesteps)):
+        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+            # concat latents, mask, masked_image_latents in the channel dimension
+            latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
+            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+            latent_model_input = latent_model_input.numpy()

             # predict the noise residual
             noise_pred = self.unet(
```
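The round-trip through torch in this hunk exists because the schedulers are torch-based while ONNX Runtime consumes numpy arrays; a sketch of the conversion pattern:

```python
import numpy as np
import torch

# ONNX Runtime works on numpy arrays, but the diffusers schedulers expect
# torch tensors, hence the numpy -> torch -> numpy round-trip.
latent_model_input = np.random.randn(2, 9, 64, 64).astype(np.float32)
as_torch = torch.from_numpy(latent_model_input)  # shares memory, no copy
back = as_torch.numpy()                          # back to numpy, still no copy
assert back.shape == (2, 9, 64, 64)
```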
```diff
@@ -353,12 +361,6 @@ def __call__(
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
             latents = latents.numpy()
-            # masking
-            init_latents_proper = self.scheduler.add_noise(
-                torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.tensor([t])
-            )
-
-            latents = (init_latents_proper * mask) + (latents * (1 - mask))

             # call the callback, if provided
             if callback is not None and i % callback_steps == 0:
```
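A note on the deletion above: the legacy pipeline kept the unmasked region faithful by re-noising the original latents and blending them back in after every step. The supervised UNet instead sees the mask and masked-image latents directly as extra input channels (the concatenation in the previous hunk), so the per-step blending is no longer needed.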

src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -49,6 +49,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])


+class StableDiffusionInpaintPipelineLegacy(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class StableDiffusionPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
```
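These placeholder classes let `import diffusers` succeed when torch or transformers is missing; only touching the class raises an error. A sketch of the intended failure mode, assuming the backends are absent:

```python
# Assuming torch/transformers are NOT installed: the placeholder imports
# fine, but instantiating it raises an ImportError naming the missing backends.
from diffusers.utils.dummy_torch_and_transformers_objects import (
    StableDiffusionInpaintPipelineLegacy,
)

try:
    StableDiffusionInpaintPipelineLegacy()
except ImportError as err:
    print(err)  # points the user at installing torch and transformers
```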

tests/test_pipelines.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -2271,7 +2271,7 @@ def test_stable_diffusion_inpaint_onnx(self):
         )

         pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
+            "runwayml/stable-diffusion-inpainting", revision="onnx", provider="CPUExecutionProvider"
         )
         pipe.set_progress_bar_config(disable=None)

@@ -2280,9 +2280,8 @@ def test_stable_diffusion_inpaint_onnx(self):
         np.random.seed(0)
         output = pipe(
             prompt=prompt,
-            init_image=init_image,
+            image=init_image,
             mask_image=mask_image,
-            strength=0.75,
             guidance_scale=7.5,
             num_inference_steps=8,
             output_type="np",
@@ -2291,7 +2290,7 @@ def test_stable_diffusion_inpaint_onnx(self):
         image_slice = images[0, 255:258, 255:258, -1]

         assert images.shape == (1, 512, 512, 3)
-        expected_slice = np.array([0.3524, 0.3289, 0.3464, 0.3872, 0.4129, 0.3566, 0.3709, 0.4128, 0.3734])
+        expected_slice = np.array([0.2951, 0.2955, 0.2922, 0.2036, 0.1977, 0.2279, 0.1716, 0.1641, 0.1799])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

     @slow
```
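The reference slice changes because the test now runs the supervised `runwayml/stable-diffusion-inpainting` checkpoint without a `strength` argument (the commit message's "update ref values" and "back to 8 steps"); with `np.random.seed(0)` pinning the latents, the 3x3 slice comparison at 1e-3 tolerance stays deterministic on the CPU provider.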
