[ONNX] Improve ONNXPipeline scheduler compatibility, fix safety_checker #1173
**`onnx_utils.py`**

```diff
@@ -24,7 +24,7 @@
 from huggingface_hub import hf_hub_download

-from .utils import ONNX_WEIGHTS_NAME, is_onnx_available, logging
+from .utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging


 if is_onnx_available():
@@ -33,13 +33,28 @@
 logger = logging.get_logger(__name__)

+ORT_TO_NP_TYPE = {
+    "tensor(bool)": np.bool_,
+    "tensor(int8)": np.int8,
+    "tensor(uint8)": np.uint8,
+    "tensor(int16)": np.int16,
+    "tensor(uint16)": np.uint16,
+    "tensor(int32)": np.int32,
+    "tensor(uint32)": np.uint32,
+    "tensor(int64)": np.int64,
+    "tensor(uint64)": np.uint64,
+    "tensor(float16)": np.float16,
+    "tensor(float)": np.float32,
+    "tensor(double)": np.float64,
+}
+

 class OnnxRuntimeModel:
     def __init__(self, model=None, **kwargs):
         logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
         self.model = model
         self.model_save_dir = kwargs.get("model_save_dir", None)
-        self.latest_model_name = kwargs.get("latest_model_name", "model.onnx")
+        self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME)

     def __call__(self, **kwargs):
         inputs = {k: np.array(v) for k, v in kwargs.items()}
@@ -84,6 +99,15 @@ def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional
         except shutil.SameFileError:
             pass

+        # copy external weights (for models >2GB)
+        src_path = self.model_save_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
+        if src_path.exists():
+            dst_path = Path(save_directory).joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
+            try:
+                shutil.copyfile(src_path, dst_path)
+            except shutil.SameFileError:
+                pass
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
```

> **Review comment** (Member, Author) on lines +102 to +109: The external […]
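For context on the external-weights copy above: ONNX models are serialized as protobuf, and a single protobuf file is capped at 2GB, so larger models (such as a full-precision unet) keep their tensors in a sidecar external-data file that has to travel alongside `model.onnx`. Below is a minimal sketch of producing such a sidecar with the `onnx` package; the paths and the `weights.pb` file name are illustrative, not necessarily what `ONNX_EXTERNAL_WEIGHTS_NAME` resolves to.

```python
import onnx

# Loading by path also pulls in any external data stored next to the file.
model = onnx.load("exported/model.onnx")

# Re-save with all tensors gathered into a single sidecar file; anything
# larger than size_threshold bytes is moved out of the protobuf so the
# main file stays under the 2GB cap.
onnx.save_model(
    model,
    "exported/model.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="weights.pb",
    size_threshold=1024,
)
```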
**`pipeline_utils.py`**

```diff
@@ -541,7 +541,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             # if the model is in a pipeline module, then we load it from the pipeline
             if name in passed_class_obj:
                 # 1. check that passed_class_obj has correct parent class
-                if not is_pipeline_module:
+                if not is_pipeline_module and passed_class_obj[name] is not None:
                     library = importlib.import_module(library_name)
                     class_obj = getattr(library, class_name)
                     importable_classes = LOADABLE_CLASSES[library_name]
```

> **Review comment** (Member, Author): Since […]
>
> **Reply** (Contributor): ok!
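With the added `is not None` guard, an explicitly passed `None` component no longer trips the parent-class check, so a component such as the safety checker can be disabled at load time. A hypothetical usage sketch follows; the class, checkpoint, and provider names are illustrative, not taken from this PR.

```python
from diffusers import OnnxStableDiffusionPipeline

# Passing safety_checker=None now skips the parent-class validation for
# that component in from_pretrained instead of raising, effectively
# disabling the safety checker.
pipe = OnnxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="onnx",
    provider="CPUExecutionProvider",
    safety_checker=None,
)
image = pipe("an astronaut riding a horse").images[0]
```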
**`pipeline_onnx_stable_diffusion.py`** (the ONNX Stable Diffusion pipeline)

```diff
@@ -2,11 +2,12 @@
 from typing import Callable, List, Optional, Union

 import numpy as np
+import torch

 from transformers import CLIPFeatureExtractor, CLIPTokenizer

 from ...configuration_utils import FrozenDict
-from ...onnx_utils import OnnxRuntimeModel
+from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import deprecate, logging
```
```diff
@@ -186,7 +187,7 @@ def __call__(
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)

-        latents = latents * self.scheduler.init_noise_sigma
+        latents = latents * np.float(self.scheduler.init_noise_sigma)

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
```
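The cast matters because `init_noise_sigma` is a torch scalar tensor for schedulers like `LMSDiscreteScheduler`, while the latents here are a numpy array; forcing it to a plain float keeps the product a numpy `float32` array. A small self-contained sketch of the same idea (the sigma value is a stand-in; note also that `np.float` is just a deprecated alias for the builtin `float`, removed in NumPy 1.24):

```python
import numpy as np
import torch

latents = np.random.randn(1, 4, 64, 64).astype(np.float32)
init_noise_sigma = torch.tensor(14.6146)  # stand-in for scheduler.init_noise_sigma

# Casting the torch scalar to a Python float keeps the result a float32 ndarray.
latents = latents * float(init_noise_sigma)
assert latents.dtype == np.float32
```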
```diff
@@ -197,15 +198,20 @@
         if accepts_eta:
             extra_step_kwargs["eta"] = eta

+        timestep_dtype = next(
+            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+        )
+        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+            latent_model_input = latent_model_input.cpu().numpy()

             # predict the noise residual
-            noise_pred = self.unet(
-                sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
-            )
+            timestep = np.array([t], dtype=timestep_dtype)
+            noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings)
             noise_pred = noise_pred[0]

             # perform guidance
```

> **Review comment** (Member, Author) on lines +201 to +204: Passing float timesteps (as in K-LMS) to the onnx unet will result in a type error. This adds a bit of back-compatibility until we update checkpoints to accept […] (Can't change the timesteps to […]
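ONNX Runtime rejects inputs whose dtype does not match the graph's declared input type, which is why the pipeline reads the unet's declared `timestep` type and casts through `ORT_TO_NP_TYPE`. A standalone sketch of the same lookup against a raw `onnxruntime.InferenceSession` (the model path is illustrative, and the mapping is abridged from the full dict added above):

```python
import numpy as np
import onnxruntime as ort

# Abridged from the ORT_TO_NP_TYPE mapping introduced in onnx_utils.py.
ORT_TO_NP_TYPE = {"tensor(float)": np.float32, "tensor(int64)": np.int64}

sess = ort.InferenceSession("unet/model.onnx")  # illustrative path
# Fall back to tensor(float) when the graph has no input named "timestep".
timestep_type = next(
    (inp.type for inp in sess.get_inputs() if inp.name == "timestep"), "tensor(float)"
)
timestep = np.array([999], dtype=ORT_TO_NP_TYPE[timestep_type])
```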
```diff
@@ -214,7 +220,7 @@
             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
+            latents = np.array(latents)

             # call the callback, if provided
```
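diffusers schedulers operate on torch tensors, so the ONNX pipeline bridges numpy to torch before each scheduler call and back to numpy afterwards. A minimal sketch of that round-trip as a helper (the helper name is mine, not from this PR; the diff itself converts only `latents` and passes `noise_pred` through as numpy, which amounts to the same pattern):

```python
import numpy as np
import torch

def step_numpy(scheduler, noise_pred: np.ndarray, t, latents: np.ndarray, **kwargs) -> np.ndarray:
    """Run one torch-based scheduler step on numpy arrays."""
    # torch.from_numpy shares memory with the source array (no copy on the way in).
    prev = scheduler.step(
        torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **kwargs
    ).prev_sample
    return prev.cpu().numpy()
```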
```diff
@@ -235,6 +241,9 @@
             safety_checker_input = self.feature_extractor(
                 self.numpy_to_pil(image), return_tensors="np"
             ).pixel_values.astype(image.dtype)
-            image, has_nsfw_concepts = self.safety_checker(clip_input=safety_checker_input, images=image)
+            # the safety_checker throws an error when run with batch size > 1
+            images, has_nsfw_concept = [], []
+            for i in range(image.shape[0]):
```

> **Review comment:** Adding config values instead of hardcoded ones throughout the script
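The diff is truncated at the start of the per-image loop, so the sketch below is a hedged reconstruction of the intended pattern, not the PR's verbatim code. It reuses the names from the diff and assumes `safety_checker` is the pipeline's ONNX safety-checker model, `image` a `(batch, H, W, 3)` float32 array, and `safety_checker_input` the matching CLIP `pixel_values` batch:

```python
import numpy as np

images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
    # Slice with i:i+1 to keep a batch dimension of exactly 1,
    # since the checker errors on larger batches.
    image_i, has_nsfw_concept_i = safety_checker(
        clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
    )
    images.append(image_i)
    has_nsfw_concept.append(has_nsfw_concept_i[0])
image = np.concatenate(images)
```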