diff --git a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
index 8e0b58c56df3..f0e0b178af20 100644
--- a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
+++ b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
@@ -81,6 +81,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
     output_path = Path(output_path)
 
     # TEXT ENCODER
+    num_tokens = pipeline.text_encoder.config.max_position_embeddings
+    text_hidden_size = pipeline.text_encoder.config.hidden_size
     text_input = pipeline.tokenizer(
         "A sample prompt",
         padding="max_length",
@@ -103,13 +105,15 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
     del pipeline.text_encoder
 
     # UNET
+    unet_in_channels = pipeline.unet.config.in_channels
+    unet_sample_size = pipeline.unet.config.sample_size
     unet_path = output_path / "unet" / "model.onnx"
     onnx_export(
         pipeline.unet,
         model_args=(
-            torch.randn(2, pipeline.unet.in_channels, 64, 64).to(device=device, dtype=dtype),
-            torch.LongTensor([0, 1]).to(device=device),
-            torch.randn(2, 77, 768).to(device=device, dtype=dtype),
+            torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
+            torch.randn(2).to(device=device, dtype=dtype),
+            torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
             False,
         ),
         output_path=unet_path,
@@ -142,11 +146,16 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
 
     # VAE ENCODER
     vae_encoder = pipeline.vae
+    vae_in_channels = vae_encoder.config.in_channels
+    vae_sample_size = vae_encoder.config.sample_size
     # need to get the raw tensor output (sample) from the encoder
     vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
     onnx_export(
         vae_encoder,
-        model_args=(torch.randn(1, 3, 512, 512).to(device=device, dtype=dtype), False),
+        model_args=(
+            torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(device=device, dtype=dtype),
+            False,
+        ),
         output_path=output_path / "vae_encoder" / "model.onnx",
         ordered_input_names=["sample", "return_dict"],
         output_names=["latent_sample"],
@@ -158,11 +167,16 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
 
     # VAE DECODER
     vae_decoder = pipeline.vae
+    vae_latent_channels = vae_decoder.config.latent_channels
+    vae_out_channels = vae_decoder.config.out_channels
     # forward only through the decoder part
     vae_decoder.forward = vae_encoder.decode
     onnx_export(
         vae_decoder,
-        model_args=(torch.randn(1, 4, 64, 64).to(device=device, dtype=dtype), False),
+        model_args=(
+            torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
+            False,
+        ),
         output_path=output_path / "vae_decoder" / "model.onnx",
         ordered_input_names=["latent_sample", "return_dict"],
         output_names=["sample"],
@@ -174,24 +188,35 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
     del pipeline.vae
 
     # SAFETY CHECKER
-    safety_checker = pipeline.safety_checker
-    safety_checker.forward = safety_checker.forward_onnx
-    onnx_export(
-        pipeline.safety_checker,
-        model_args=(
-            torch.randn(1, 3, 224, 224).to(device=device, dtype=dtype),
-            torch.randn(1, 512, 512, 3).to(device=device, dtype=dtype),
-        ),
-        output_path=output_path / "safety_checker" / "model.onnx",
-        ordered_input_names=["clip_input", "images"],
-        output_names=["out_images", "has_nsfw_concepts"],
-        dynamic_axes={
-            "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
-            "images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
-        },
-        opset=opset,
-    )
-    del pipeline.safety_checker
+    if pipeline.safety_checker is not None:
+        safety_checker = pipeline.safety_checker
+        clip_num_channels = safety_checker.config.vision_config.num_channels
+        clip_image_size = safety_checker.config.vision_config.image_size
+        safety_checker.forward = safety_checker.forward_onnx
+        onnx_export(
+            pipeline.safety_checker,
+            model_args=(
+                torch.randn(
+                    1,
+                    clip_num_channels,
+                    clip_image_size,
+                    clip_image_size,
+                ).to(device=device, dtype=dtype),
+                torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(device=device, dtype=dtype),
+            ),
+            output_path=output_path / "safety_checker" / "model.onnx",
+            ordered_input_names=["clip_input", "images"],
+            output_names=["out_images", "has_nsfw_concepts"],
+            dynamic_axes={
+                "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+                "images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
+            },
+            opset=opset,
+        )
+        del pipeline.safety_checker
+        safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
+    else:
+        safety_checker = None
 
     onnx_pipeline = OnnxStableDiffusionPipeline(
         vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
@@ -200,7 +225,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
         tokenizer=pipeline.tokenizer,
         unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
         scheduler=pipeline.scheduler,
-        safety_checker=OnnxRuntimeModel.from_pretrained(output_path / "safety_checker"),
+        safety_checker=safety_checker,
         feature_extractor=pipeline.feature_extractor,
     )
 
diff --git a/src/diffusers/onnx_utils.py b/src/diffusers/onnx_utils.py
index 142174f6e101..b2c533ed741f 100644
--- a/src/diffusers/onnx_utils.py
+++ b/src/diffusers/onnx_utils.py
@@ -24,7 +24,7 @@
 
 from huggingface_hub import hf_hub_download
 
-from .utils import ONNX_WEIGHTS_NAME, is_onnx_available, logging
+from .utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging
 
 
 if is_onnx_available():
@@ -33,13 +33,28 @@
 
 logger = logging.get_logger(__name__)
 
+ORT_TO_NP_TYPE = {
+    "tensor(bool)": np.bool_,
+    "tensor(int8)": np.int8,
+    "tensor(uint8)": np.uint8,
+    "tensor(int16)": np.int16,
+    "tensor(uint16)": np.uint16,
+    "tensor(int32)": np.int32,
+    "tensor(uint32)": np.uint32,
+    "tensor(int64)": np.int64,
+    "tensor(uint64)": np.uint64,
+    "tensor(float16)": np.float16,
+    "tensor(float)": np.float32,
+    "tensor(double)": np.float64,
+}
+
 
 class OnnxRuntimeModel:
     def __init__(self, model=None, **kwargs):
         logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
         self.model = model
         self.model_save_dir = kwargs.get("model_save_dir", None)
-        self.latest_model_name = kwargs.get("latest_model_name", "model.onnx")
+        self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME)
 
     def __call__(self, **kwargs):
         inputs = {k: np.array(v) for k, v in kwargs.items()}
@@ -84,6 +99,15 @@ def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional
         except shutil.SameFileError:
             pass
 
+        # copy external weights (for models >2GB)
+        src_path = self.model_save_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
+        if src_path.exists():
+            dst_path = Path(save_directory).joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
+            try:
+                shutil.copyfile(src_path, dst_path)
+            except shutil.SameFileError:
+                pass
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 97e196e72338..e328c1d65cde 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -541,7 +541,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         # if the model is in a pipeline module, then we load it from the pipeline
         if name in passed_class_obj:
             # 1. check that passed_class_obj has correct parent class
-            if not is_pipeline_module:
+            if not is_pipeline_module and passed_class_obj[name] is not None:
                 library = importlib.import_module(library_name)
                 class_obj = getattr(library, class_name)
                 importable_classes = LOADABLE_CLASSES[library_name]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
index 22f5bf6c432d..0c50e424e249 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
@@ -2,11 +2,12 @@
 from typing import Callable, List, Optional, Union
 
 import numpy as np
+import torch
 
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
-from ...onnx_utils import OnnxRuntimeModel
+from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import deprecate, logging
@@ -186,7 +187,7 @@ def __call__(
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
 
-        latents = latents * self.scheduler.init_noise_sigma
+        latents = latents * np.float64(self.scheduler.init_noise_sigma)
 
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -197,15 +198,20 @@
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
+        timestep_dtype = next(
+            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+        )
+        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+            latent_model_input = latent_model_input.cpu().numpy()
 
             # predict the noise residual
-            noise_pred = self.unet(
-                sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
-            )
+            timestep = np.array([t], dtype=timestep_dtype)
+            noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings)
             noise_pred = noise_pred[0]
 
             # perform guidance
@@ -214,7 +220,7 @@
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
             latents = np.array(latents)
 
             # call the callback, if provided
@@ -235,6 +241,6 @@
             safety_checker_input = self.feature_extractor(
                 self.numpy_to_pil(image), return_tensors="np"
             ).pixel_values.astype(image.dtype)
-            # There will throw an error if use safety_checker batchsize>1
+            # safety_checker does not support batched inputs yet
             images, has_nsfw_concept = [], []
             for i in range(image.shape[0]):
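Note: the denoising loop above keeps latents as numpy arrays for ONNX Runtime while the schedulers operate on torch tensors, hence the torch.from_numpy(...) round-trips. A minimal sketch of that bridging as a standalone helper (the helper name is hypothetical; it wraps both arrays, which also covers schedulers that do tensor-only arithmetic on the model output):

    import numpy as np
    import torch

    def scheduler_step_numpy(scheduler, noise_pred: np.ndarray, t, latents: np.ndarray) -> np.ndarray:
        # wrap the numpy arrays for the torch-based scheduler, then convert back
        prev_sample = scheduler.step(
            torch.from_numpy(noise_pred), t, torch.from_numpy(latents)
        ).prev_sample
        return prev_sample.cpu().numpy()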
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
index 5e6b2e6f2fc3..e0582bbbce80 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
@@ -8,7 +8,7 @@
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
-from ...onnx_utils import OnnxRuntimeModel
+from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import deprecate, logging
@@ -338,14 +338,21 @@
         t_start = max(num_inference_steps - init_timestep + offset, 0)
         timesteps = self.scheduler.timesteps[t_start:].numpy()
 
+        timestep_dtype = next(
+            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+        )
+        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
         for i, t in enumerate(self.progress_bar(timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+            latent_model_input = latent_model_input.cpu().numpy()
 
             # predict the noise residual
+            timestep = np.array([t], dtype=timestep_dtype)
             noise_pred = self.unet(
-                sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
+                sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings
             )[0]
 
             # perform guidance
@@ -354,7 +361,7 @@
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
             latents = latents.numpy()
 
             # call the callback, if provided
@@ -375,7 +382,7 @@
             safety_checker_input = self.feature_extractor(
                 self.numpy_to_pil(image), return_tensors="np"
             ).pixel_values.astype(image.dtype)
-            # There will throw an error if use safety_checker batchsize>1
+            # safety_checker does not support batched inputs yet
             images, has_nsfw_concept = [], []
             for i in range(image.shape[0]):
                 image_i, has_nsfw_concept_i = self.safety_checker(
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
index 2ce9831a1676..a9ff88c8c063 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -8,7 +8,7 @@
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
-from ...onnx_utils import OnnxRuntimeModel
+from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import deprecate, logging
@@ -352,7 +352,7 @@
         self.scheduler.set_timesteps(num_inference_steps)
 
         # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * self.scheduler.init_noise_sigma
+        latents = latents * np.float64(self.scheduler.init_noise_sigma)
 
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -363,17 +363,23 @@
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
+        timestep_dtype = next(
+            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+        )
+        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
             # concat latents, mask, masked_image_latents in the channel dimension
             latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
             latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
-            latent_model_input = latent_model_input.numpy()
+            latent_model_input = latent_model_input.cpu().numpy()
 
             # predict the noise residual
+            timestep = np.array([t], dtype=timestep_dtype)
             noise_pred = self.unet(
-                sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
+                sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings
             )[0]
 
             # perform guidance
@@ -382,7 +388,7 @@
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
             latents = latents.numpy()
 
             # call the callback, if provided
@@ -403,7 +409,7 @@
             safety_checker_input = self.feature_extractor(
                 self.numpy_to_pil(image), return_tensors="np"
             ).pixel_values.astype(image.dtype)
-            # There will throw an error if use safety_checker batchsize>1
+            # safety_checker does not support batched inputs yet
             images, has_nsfw_concept = [], []
             for i in range(image.shape[0]):
                 image_i, has_nsfw_concept_i = self.safety_checker(
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 3fa477e7dce8..a00e1f4dcd4c 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -67,6 +67,7 @@
 WEIGHTS_NAME = "diffusion_pytorch_model.bin"
 FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
 ONNX_WEIGHTS_NAME = "model.onnx"
+ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
 HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
 DIFFUSERS_CACHE = default_cache_path
 DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py
index d8356675e9b3..a1946e39f912 100644
--- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py
@@ -13,11 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import tempfile
 import unittest
 
 import numpy as np
 
-from diffusers import OnnxStableDiffusionPipeline
+from diffusers import DDIMScheduler, LMSDiscreteScheduler, OnnxStableDiffusionPipeline
 from diffusers.utils.testing_utils import is_onnx_available, require_onnxruntime, require_torch_gpu, slow
 
 from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
@@ -36,32 +37,87 @@ class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.Tes
 @require_onnxruntime
 @require_torch_gpu
 class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
-    def test_inference(self):
-        provider = (
+    @property
+    def gpu_provider(self):
+        return (
             "CUDAExecutionProvider",
             {
-                "gpu_mem_limit": "17179869184",  # 16GB.
+                "gpu_mem_limit": "15000000000",  # 15GB
                 "arena_extend_strategy": "kSameAsRequested",
             },
         )
+
+    @property
+    def gpu_options(self):
         options = ort.SessionOptions()
         options.enable_mem_pattern = False
+        return options
+
+    def test_inference_default_pndm(self):
+        # using the PNDM scheduler by default
         sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
             "CompVis/stable-diffusion-v1-4",
             revision="onnx",
-            provider=provider,
-            sess_options=options,
+            provider=self.gpu_provider,
+            sess_options=self.gpu_options,
         )
+        sd_pipe.set_progress_bar_config(disable=None)
 
         prompt = "A painting of a squirrel eating a burger"
         np.random.seed(0)
-        output = sd_pipe([prompt], guidance_scale=6.0, num_inference_steps=5, output_type="np")
+        output = sd_pipe([prompt], guidance_scale=6.0, num_inference_steps=10, output_type="np")
         image = output.images
         image_slice = image[0, -3:, -3:, -1]
 
         assert image.shape == (1, 512, 512, 3)
-        expected_slice = np.array([0.3602, 0.3688, 0.3652, 0.3895, 0.3782, 0.3747, 0.3927, 0.4241, 0.4327])
+        expected_slice = np.array([0.0452, 0.0390, 0.0087, 0.0350, 0.0617, 0.0364, 0.0544, 0.0523, 0.0720])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
+    def test_inference_ddim(self):
+        ddim_scheduler = DDIMScheduler.from_config(
+            "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
+        )
+        sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5",
+            revision="onnx",
+            scheduler=ddim_scheduler,
+            provider=self.gpu_provider,
+            sess_options=self.gpu_options,
+        )
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "open neural network exchange"
+        generator = np.random.RandomState(0)
+        output = sd_pipe([prompt], guidance_scale=7.5, num_inference_steps=10, generator=generator, output_type="np")
+        image = output.images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 512, 512, 3)
+        expected_slice = np.array([0.2867, 0.1974, 0.1481, 0.7294, 0.7251, 0.6667, 0.4194, 0.5642, 0.6486])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
+    def test_inference_k_lms(self):
+        lms_scheduler = LMSDiscreteScheduler.from_config(
+            "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
+        )
+        sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5",
+            revision="onnx",
+            scheduler=lms_scheduler,
+            provider=self.gpu_provider,
+            sess_options=self.gpu_options,
+        )
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "open neural network exchange"
+        generator = np.random.RandomState(0)
+        output = sd_pipe([prompt], guidance_scale=7.5, num_inference_steps=10, generator=generator, output_type="np")
+        image = output.images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 512, 512, 3)
+        expected_slice = np.array([0.2306, 0.1959, 0.1593, 0.6549, 0.6394, 0.5408, 0.5065, 0.6010, 0.6161])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
 
     def test_intermediate_state(self):
@@ -75,27 +131,61 @@ def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None:
                 assert latents.shape == (1, 4, 64, 64)
                 latents_slice = latents[0, -3:, -3:, -1]
                 expected_slice = np.array(
-                    [-0.5950, -0.3039, -1.1672, 0.1594, -1.1572, 0.6719, -1.9712, -0.0403, 0.9592]
+                    [-0.6772, -0.3835, -1.2456, 0.1905, -1.0974, 0.6967, -1.9353, 0.0178, 1.0167]
                 )
                 assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
             elif step == 5:
                 assert latents.shape == (1, 4, 64, 64)
                 latents_slice = latents[0, -3:, -3:, -1]
                 expected_slice = np.array(
-                    [-0.4776, -0.0119, -0.8519, -0.0275, -0.9764, 0.9820, -0.3843, 0.3788, 1.2264]
+                    [-0.3351, 0.2241, -0.1837, -0.2325, -0.6577, 0.3393, -0.0241, 0.5899, 1.3875]
                 )
                 assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
 
         test_callback_fn.has_been_called = False
 
         pipe = OnnxStableDiffusionPipeline.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CUDAExecutionProvider"
+            "runwayml/stable-diffusion-v1-5",
+            revision="onnx",
+            provider=self.gpu_provider,
+            sess_options=self.gpu_options,
         )
         pipe.set_progress_bar_config(disable=None)
 
         prompt = "Andromeda galaxy in a bottle"
 
-        np.random.seed(0)
-        pipe(prompt=prompt, num_inference_steps=5, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1)
+        generator = np.random.RandomState(0)
+        pipe(
+            prompt=prompt,
+            num_inference_steps=5,
+            guidance_scale=7.5,
+            generator=generator,
+            callback=test_callback_fn,
+            callback_steps=1,
+        )
         assert test_callback_fn.has_been_called
         assert number_of_steps == 6
+
+    def test_stable_diffusion_no_safety_checker(self):
+        pipe = OnnxStableDiffusionPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5",
+            revision="onnx",
+            provider=self.gpu_provider,
+            sess_options=self.gpu_options,
+            safety_checker=None,
+        )
+        assert isinstance(pipe, OnnxStableDiffusionPipeline)
+        assert pipe.safety_checker is None
+
+        image = pipe("example prompt", num_inference_steps=2).images[0]
+        assert image is not None
+
+        # check that there's no error when saving a pipeline with one of the models being None
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            pipe.save_pretrained(tmpdirname)
+            pipe = OnnxStableDiffusionPipeline.from_pretrained(tmpdirname)
+
+        # sanity check that the pipeline still works
+        assert pipe.safety_checker is None
+        image = pipe("example prompt", num_inference_steps=2).images[0]
+        assert image is not None
diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py
index 3ffbfc3d4f18..61831c64c0ae 100644
--- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py
+++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py
@@ -17,7 +17,7 @@
 
 import numpy as np
 
-from diffusers import OnnxStableDiffusionImg2ImgPipeline
+from diffusers import LMSDiscreteScheduler, OnnxStableDiffusionImg2ImgPipeline
 from diffusers.utils.testing_utils import is_onnx_available, load_image, require_onnxruntime, require_torch_gpu, slow
 
 from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
@@ -35,45 +35,92 @@ class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.Tes
 @slow
 @require_onnxruntime
 @require_torch_gpu
-class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
-    def test_inference(self):
-        init_image = load_image(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-            "/img2img/sketch-mountains-input.jpg"
-        )
-        init_image = init_image.resize((768, 512))
-        provider = (
+class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
+    @property
+    def gpu_provider(self):
+        return (
             "CUDAExecutionProvider",
             {
-                "gpu_mem_limit": "17179869184",  # 16GB.
+                "gpu_mem_limit": "15000000000",  # 15GB
                 "arena_extend_strategy": "kSameAsRequested",
             },
         )
+
+    @property
+    def gpu_options(self):
         options = ort.SessionOptions()
         options.enable_mem_pattern = False
+        return options
+
+    def test_inference_default_pndm(self):
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/sketch-mountains-input.jpg"
+        )
+        init_image = init_image.resize((768, 512))
+        # using the PNDM scheduler by default
         pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
             "CompVis/stable-diffusion-v1-4",
             revision="onnx",
-            provider=provider,
-            sess_options=options,
+            provider=self.gpu_provider,
+            sess_options=self.gpu_options,
+        )
+        pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A fantasy landscape, trending on artstation"
+
+        generator = np.random.RandomState(0)
+        output = pipe(
+            prompt=prompt,
+            init_image=init_image,
+            strength=0.75,
+            guidance_scale=7.5,
+            num_inference_steps=10,
+            generator=generator,
+            output_type="np",
+        )
+        images = output.images
+        image_slice = images[0, 255:258, 383:386, -1]
+
+        assert images.shape == (1, 512, 768, 3)
+        expected_slice = np.array([0.4909, 0.5059, 0.5372, 0.4623, 0.4876, 0.5049, 0.4820, 0.4956, 0.5019])
+        # TODO: lower the tolerance after finding the cause of onnxruntime reproducibility issues
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
+
+    def test_inference_k_lms(self):
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/sketch-mountains-input.jpg"
+        )
+        init_image = init_image.resize((768, 512))
+        lms_scheduler = LMSDiscreteScheduler.from_config(
+            "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
+        )
+        pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5",
+            revision="onnx",
+            scheduler=lms_scheduler,
+            provider=self.gpu_provider,
+            sess_options=self.gpu_options,
         )
         pipe.set_progress_bar_config(disable=None)
 
         prompt = "A fantasy landscape, trending on artstation"
 
-        np.random.seed(0)
+        generator = np.random.RandomState(0)
         output = pipe(
             prompt=prompt,
             init_image=init_image,
             strength=0.75,
             guidance_scale=7.5,
-            num_inference_steps=8,
+            num_inference_steps=10,
+            generator=generator,
             output_type="np",
         )
         images = output.images
         image_slice = images[0, 255:258, 383:386, -1]
 
         assert images.shape == (1, 512, 768, 3)
-        expected_slice = np.array([0.4830, 0.5242, 0.5603, 0.5016, 0.5131, 0.5111, 0.4928, 0.5025, 0.5055])
+        expected_slice = np.array([0.7950, 0.7923, 0.7903, 0.5516, 0.5501, 0.5476, 0.4965, 0.4933, 0.4910])
         # TODO: lower the tolerance after finding the cause of onnxruntime reproducibility issues
         assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
index 81cbed4e510d..4ba8e273b497 100644
--- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
@@ -17,7 +17,7 @@
 
 import numpy as np
 
-from diffusers import OnnxStableDiffusionInpaintPipeline
+from diffusers import LMSDiscreteScheduler, OnnxStableDiffusionInpaintPipeline
 from diffusers.utils.testing_utils import is_onnx_available, load_image, require_onnxruntime, require_torch_gpu, slow
 
 from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
@@ -35,8 +35,24 @@ class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.Tes
 @slow
 @require_onnxruntime
 @require_torch_gpu
-class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
-    def test_stable_diffusion_inpaint_onnx(self):
+class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
+    @property
+    def gpu_provider(self):
+        return (
+            "CUDAExecutionProvider",
+            {
+                "gpu_mem_limit": "15000000000",  # 15GB
+                "arena_extend_strategy": "kSameAsRequested",
+            },
+        )
+
+    @property
+    def gpu_options(self):
+        options = ort.SessionOptions()
+        options.enable_mem_pattern = False
+        return options
+
+    def test_inference_default_pndm(self):
         init_image = load_image(
             "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
             "/in_paint/overture-creations-5sI6fQgYIuo.png"
@@ -45,37 +61,69 @@ def test_stable_diffusion_inpaint_onnx(self):
             "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
             "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
         )
-        provider = (
-            "CUDAExecutionProvider",
-            {
-                "gpu_mem_limit": "17179869184",  # 16GB.
-                "arena_extend_strategy": "kSameAsRequested",
-            },
+        pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
+            "runwayml/stable-diffusion-inpainting",
+            revision="onnx",
+            provider=self.gpu_provider,
+            sess_options=self.gpu_options,
+        )
+        pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A red cat sitting on a park bench"
+
+        generator = np.random.RandomState(0)
+        output = pipe(
+            prompt=prompt,
+            image=init_image,
+            mask_image=mask_image,
+            guidance_scale=7.5,
+            num_inference_steps=10,
+            generator=generator,
+            output_type="np",
+        )
+        images = output.images
+        image_slice = images[0, 255:258, 255:258, -1]
+
+        assert images.shape == (1, 512, 512, 3)
+        expected_slice = np.array([0.2514, 0.3007, 0.3517, 0.1790, 0.2382, 0.3167, 0.1944, 0.2273, 0.2464])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
+    def test_inference_k_lms(self):
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/overture-creations-5sI6fQgYIuo.png"
+        )
+        mask_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
+        )
+        lms_scheduler = LMSDiscreteScheduler.from_config(
+            "runwayml/stable-diffusion-inpainting", subfolder="scheduler", revision="onnx"
         )
-        options = ort.SessionOptions()
-        options.enable_mem_pattern = False
 
         pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
             "runwayml/stable-diffusion-inpainting",
             revision="onnx",
-            provider=provider,
-            sess_options=options,
+            scheduler=lms_scheduler,
+            provider=self.gpu_provider,
+            sess_options=self.gpu_options,
         )
         pipe.set_progress_bar_config(disable=None)
 
         prompt = "A red cat sitting on a park bench"
 
-        np.random.seed(0)
+        generator = np.random.RandomState(0)
         output = pipe(
             prompt=prompt,
             image=init_image,
             mask_image=mask_image,
             guidance_scale=7.5,
-            num_inference_steps=8,
+            num_inference_steps=10,
+            generator=generator,
             output_type="np",
         )
         images = output.images
         image_slice = images[0, 255:258, 255:258, -1]
 
         assert images.shape == (1, 512, 512, 3)
-        expected_slice = np.array([0.2951, 0.2955, 0.2922, 0.2036, 0.1977, 0.2279, 0.1716, 0.1641, 0.1799])
+        expected_slice = np.array([0.2520, 0.2743, 0.2643, 0.2641, 0.2517, 0.2650, 0.2498, 0.2688, 0.2529])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3