diff --git a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py index a9bf4a6c1698..2a2c2031a7b5 100644 --- a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py +++ b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py @@ -99,7 +99,12 @@ def convert_models(model_path: str, output_path: str, opset: int): unet_path = output_path / "unet" / "model.onnx" onnx_export( pipeline.unet, - model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False), + model_args=( + torch.randn(2, pipeline.unet.in_channels, 64, 64), + torch.LongTensor([0, 1]), + torch.randn(2, 77, 768), + False, + ), output_path=unet_path, ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"], output_names=["out_sample"], # has to be different from "sample" for correct tracing diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py index 3e23ef29eef9..30f8d7fcc3b8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py @@ -5,7 +5,6 @@ import torch import PIL -from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTokenizer from ...configuration_utils import FrozenDict @@ -16,28 +15,29 @@ from . import StableDiffusionPipelineOutput -logger = logging.get_logger(__name__) +logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def preprocess_image(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) - image = np.array(image).astype(np.float32) / 255.0 +NUM_UNET_INPUT_CHANNELS = 9 +NUM_LATENT_CHANNELS = 4 + + +def prepare_mask_and_masked_image(image, mask, latents_shape): + image = np.array(image.convert("RGB")) image = image[None].transpose(0, 3, 1, 2) - return 2.0 * image - 1.0 + image = image.astype(np.float32) / 127.5 - 1.0 + image_mask = np.array(mask.convert("L")) + masked_image = image * (image_mask < 127.5) -def preprocess_mask(mask): - mask = mask.convert("L") - w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST) - mask = np.array(mask).astype(np.float32) / 255.0 - mask = np.tile(mask, (4, 1, 1)) - mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? - mask = 1 - mask # repaint white, keep black - return mask + mask = mask.resize((latents_shape[1], latents_shape[0]), PIL.Image.NEAREST) + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + return mask, masked_image class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): @@ -129,14 +129,16 @@ def __init__( def __call__( self, prompt: Union[str, List[str]], - init_image: Union[np.ndarray, PIL.Image.Image], - mask_image: Union[np.ndarray, PIL.Image.Image], - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, + image: PIL.Image.Image, + mask_image: PIL.Image.Image, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, + eta: float = 0.0, + latents: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, np.ndarray], None]] = None, @@ -149,22 +151,21 @@ def __call__( Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - init_image (`np.ndarray` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. This is the image whose masked region will be inpainted. - mask_image (`np.ndarray` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a - PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should - contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` - is 1, the denoising process will be run on the masked area for the full number of iterations specified - in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more - noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): - The reference number of denoising steps. More denoising steps usually lead to a higher quality image at - the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -179,6 +180,10 @@ def __call__( eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. + latents (`np.ndarray`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -206,8 +211,8 @@ def __call__( else: raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) @@ -285,41 +290,46 @@ def __call__( # to avoid doing two forward passes text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) - # preprocess image - if not isinstance(init_image, torch.FloatTensor): - init_image = preprocess_image(init_image) + num_channels_latents = NUM_LATENT_CHANNELS + latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) + latents_dtype = text_embeddings.dtype + if latents is None: + latents = np.random.randn(*latents_shape).astype(latents_dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - # encode the init image into latents and scale the latents - init_latents = self.vae_encoder(sample=init_image)[0] - init_latents = 0.18215 * init_latents + # prepare mask and masked_image + mask, masked_image = prepare_mask_and_masked_image(image, mask_image, latents_shape[-2:]) + mask = mask.astype(latents.dtype) + masked_image = masked_image.astype(latents.dtype) - # Expand init_latents for batch_size and num_images_per_prompt - init_latents = np.concatenate([init_latents] * batch_size * num_images_per_prompt, axis=0) - init_latents_orig = init_latents + masked_image_latents = self.vae_encoder(sample=masked_image)[0] + masked_image_latents = 0.18215 * masked_image_latents - # preprocess mask - if not isinstance(mask_image, np.ndarray): - mask_image = preprocess_mask(mask_image) - mask = np.concatenate([mask_image] * batch_size * num_images_per_prompt) + mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) - # check sizes - if not mask.shape == init_latents.shape: - raise ValueError("The mask and init_image should be the same size!") + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) + unet_input_channels = NUM_UNET_INPUT_CHANNELS + if num_channels_latents + num_channels_mask + num_channels_masked_image != unet_input_channels: + raise ValueError( + "Incorrect configuration settings! The config of `pipeline.unet` expects" + f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) - timesteps = self.scheduler.timesteps.numpy()[-init_timestep] - timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) - # add noise to latents using the timesteps - noise = np.random.randn(*init_latents.shape).astype(np.float32) - init_latents = self.scheduler.add_noise( - torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) - ) - init_latents = init_latents.numpy() + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -330,15 +340,13 @@ def __call__( if accepts_eta: extra_step_kwargs["eta"] = eta - latents = init_latents - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:].numpy() - - for i, t in tqdm(enumerate(timesteps)): + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # concat latents, mask, masked_image_latnets in the channel dimension + latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1) + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.numpy() # predict the noise residual noise_pred = self.unet( @@ -353,12 +361,6 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample latents = latents.numpy() - # masking - init_latents_proper = self.scheduler.add_noise( - torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.tensor([t]) - ) - - latents = (init_latents_proper * mask) + (latents * (1 - mask)) # call the callback, if provided if callback is not None and i % callback_steps == 0: diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index ce9a02bca150..615444425c44 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -49,6 +49,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionInpaintPipelineLegacy(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 2a581e612613..357fb11a0e8c 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -2271,7 +2271,7 @@ def test_stable_diffusion_inpaint_onnx(self): ) pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider" + "runwayml/stable-diffusion-inpainting", revision="onnx", provider="CPUExecutionProvider" ) pipe.set_progress_bar_config(disable=None) @@ -2280,9 +2280,8 @@ def test_stable_diffusion_inpaint_onnx(self): np.random.seed(0) output = pipe( prompt=prompt, - init_image=init_image, + image=init_image, mask_image=mask_image, - strength=0.75, guidance_scale=7.5, num_inference_steps=8, output_type="np", @@ -2291,7 +2290,7 @@ def test_stable_diffusion_inpaint_onnx(self): image_slice = images[0, 255:258, 255:258, -1] assert images.shape == (1, 512, 512, 3) - expected_slice = np.array([0.3524, 0.3289, 0.3464, 0.3872, 0.4129, 0.3566, 0.3709, 0.4128, 0.3734]) + expected_slice = np.array([0.2951, 0.2955, 0.2922, 0.2036, 0.1977, 0.2279, 0.1716, 0.1641, 0.1799]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @slow