Merged
73 changes: 49 additions & 24 deletions scripts/convert_stable_diffusion_checkpoint_to_onnx.py
@@ -81,6 +81,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
output_path = Path(output_path)

# TEXT ENCODER
+num_tokens = pipeline.text_encoder.config.max_position_embeddings
+text_hidden_size = pipeline.text_encoder.config.hidden_size
Comment on lines +84 to +85 (Member, Author):
Adding config values instead of hardcoded ones throughout the script.
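For context, a minimal sketch of where these values come from (the model ID is illustrative; the config attributes are the ones used in the diff above):

```python
# Sketch: derive export shapes from the pipeline configs instead of
# hardcoding the SD v1 constants (77 tokens, 768 hidden size, 64/512 samples).
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

num_tokens = pipeline.text_encoder.config.max_position_embeddings  # 77 for SD v1
text_hidden_size = pipeline.text_encoder.config.hidden_size        # 768 for SD v1
unet_sample_size = pipeline.unet.config.sample_size                # 64, latent resolution
vae_sample_size = pipeline.vae.config.sample_size                  # 512, pixel resolution
print(num_tokens, text_hidden_size, unet_sample_size, vae_sample_size)
```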

text_input = pipeline.tokenizer(
"A sample prompt",
padding="max_length",
@@ -103,13 +105,15 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
del pipeline.text_encoder

# UNET
+unet_in_channels = pipeline.unet.config.in_channels
+unet_sample_size = pipeline.unet.config.sample_size
unet_path = output_path / "unet" / "model.onnx"
onnx_export(
pipeline.unet,
model_args=(
-torch.randn(2, pipeline.unet.in_channels, 64, 64).to(device=device, dtype=dtype),
-torch.LongTensor([0, 1]).to(device=device),
-torch.randn(2, 77, 768).to(device=device, dtype=dtype),
+torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
+torch.randn(2).to(device=device, dtype=dtype),
+torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
False,
),
output_path=unet_path,
@@ -142,11 +146,16 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F

# VAE ENCODER
vae_encoder = pipeline.vae
+vae_in_channels = vae_encoder.config.in_channels
+vae_sample_size = vae_encoder.config.sample_size
# need to get the raw tensor output (sample) from the encoder
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
onnx_export(
vae_encoder,
-model_args=(torch.randn(1, 3, 512, 512).to(device=device, dtype=dtype), False),
+model_args=(
+torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(device=device, dtype=dtype),
+False,
+),
output_path=output_path / "vae_encoder" / "model.onnx",
ordered_input_names=["sample", "return_dict"],
output_names=["latent_sample"],
@@ -158,11 +167,16 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F

# VAE DECODER
vae_decoder = pipeline.vae
+vae_latent_channels = vae_decoder.config.latent_channels
+vae_out_channels = vae_decoder.config.out_channels
# forward only through the decoder part
vae_decoder.forward = vae_encoder.decode
onnx_export(
vae_decoder,
-model_args=(torch.randn(1, 4, 64, 64).to(device=device, dtype=dtype), False),
+model_args=(
+torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
+False,
+),
output_path=output_path / "vae_decoder" / "model.onnx",
ordered_input_names=["latent_sample", "return_dict"],
output_names=["sample"],
@@ -174,24 +188,35 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
del pipeline.vae

# SAFETY CHECKER
-safety_checker = pipeline.safety_checker
-safety_checker.forward = safety_checker.forward_onnx
-onnx_export(
-pipeline.safety_checker,
-model_args=(
-torch.randn(1, 3, 224, 224).to(device=device, dtype=dtype),
-torch.randn(1, 512, 512, 3).to(device=device, dtype=dtype),
-),
-output_path=output_path / "safety_checker" / "model.onnx",
-ordered_input_names=["clip_input", "images"],
-output_names=["out_images", "has_nsfw_concepts"],
-dynamic_axes={
-"clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
-"images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
-},
-opset=opset,
-)
-del pipeline.safety_checker
+if pipeline.safety_checker is not None:
+safety_checker = pipeline.safety_checker
+clip_num_channels = safety_checker.config.vision_config.num_channels
+clip_image_size = safety_checker.config.vision_config.image_size
+safety_checker.forward = safety_checker.forward_onnx
+onnx_export(
+pipeline.safety_checker,
+model_args=(
+torch.randn(
+1,
+clip_num_channels,
+clip_image_size,
+clip_image_size,
+).to(device=device, dtype=dtype),
+torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(device=device, dtype=dtype),
+),
+output_path=output_path / "safety_checker" / "model.onnx",
+ordered_input_names=["clip_input", "images"],
+output_names=["out_images", "has_nsfw_concepts"],
+dynamic_axes={
+"clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+"images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
+},
+opset=opset,
+)
+del pipeline.safety_checker
+safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
+else:
+safety_checker = None

onnx_pipeline = OnnxStableDiffusionPipeline(
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
@@ -200,7 +225,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
tokenizer=pipeline.tokenizer,
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
scheduler=pipeline.scheduler,
-safety_checker=OnnxRuntimeModel.from_pretrained(output_path / "safety_checker"),
+safety_checker=safety_checker,
feature_extractor=pipeline.feature_extractor,
)

28 changes: 26 additions & 2 deletions src/diffusers/onnx_utils.py
@@ -24,7 +24,7 @@

from huggingface_hub import hf_hub_download

-from .utils import ONNX_WEIGHTS_NAME, is_onnx_available, logging
+from .utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging


if is_onnx_available():
@@ -33,13 +33,28 @@

logger = logging.get_logger(__name__)

+ORT_TO_NP_TYPE = {
+"tensor(bool)": np.bool_,
+"tensor(int8)": np.int8,
+"tensor(uint8)": np.uint8,
+"tensor(int16)": np.int16,
+"tensor(uint16)": np.uint16,
+"tensor(int32)": np.int32,
+"tensor(uint32)": np.uint32,
+"tensor(int64)": np.int64,
+"tensor(uint64)": np.uint64,
+"tensor(float16)": np.float16,
+"tensor(float)": np.float32,
+"tensor(double)": np.float64,
+}


class OnnxRuntimeModel:
def __init__(self, model=None, **kwargs):
logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
self.model = model
self.model_save_dir = kwargs.get("model_save_dir", None)
-self.latest_model_name = kwargs.get("latest_model_name", "model.onnx")
+self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME)

def __call__(self, **kwargs):
inputs = {k: np.array(v) for k, v in kwargs.items()}
@@ -84,6 +99,15 @@ def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional
except shutil.SameFileError:
pass

+# copy external weights (for models >2GB)
+src_path = self.model_save_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
+if src_path.exists():
+dst_path = Path(save_directory).joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
+try:
+shutil.copyfile(src_path, dst_path)
+except shutil.SameFileError:
+pass
Comment on lines +102 to +109 (Member, Author):
The external weights.pb wasn't copied before, breaking from_pretrained after save_pretrained.
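For illustration, a minimal sketch of the round trip this fixes (paths are hypothetical; assumes a model big enough that its ONNX export keeps tensors in an external weights.pb next to model.onnx):

```python
from diffusers import OnnxRuntimeModel

# The SD unet exceeds the 2 GB protobuf limit, so its export stores weights
# externally in weights.pb alongside model.onnx.
unet = OnnxRuntimeModel.from_pretrained("sd_onnx/unet")
unet.save_pretrained("exported_unet")  # previously copied only model.onnx

# Before this fix the reload failed with a missing external-data error,
# because weights.pb was left behind; now it is copied as well.
reloaded = OnnxRuntimeModel.from_pretrained("exported_unet")
```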


def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
2 changes: 1 addition & 1 deletion src/diffusers/pipeline_utils.py
@@ -541,7 +541,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
# 1. check that passed_class_obj has correct parent class
-if not is_pipeline_module:
+if not is_pipeline_module and passed_class_obj[name] is not None:
Comment (Member, Author):
Since safety_checker is an OnnxRuntimeModel and not a pipeline module like in the PyTorch version, it produced an error further down this branch when safety_checker == None. So let's just skip anything that's None here.

Comment (Contributor):
ok!
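For context, a rough sketch of the call this guard unblocks (the local path is hypothetical):

```python
from diffusers import OnnxStableDiffusionPipeline

# With the None check above, a component passed as None is simply skipped
# instead of tripping the parent-class validation for non-pipeline modules.
pipe = OnnxStableDiffusionPipeline.from_pretrained(
    "sd_onnx",  # e.g. a local export produced by the conversion script
    safety_checker=None,
)
```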

library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
@@ -2,11 +2,12 @@
from typing import Callable, List, Optional, Union

import numpy as np
+import torch

from transformers import CLIPFeatureExtractor, CLIPTokenizer

from ...configuration_utils import FrozenDict
-from ...onnx_utils import OnnxRuntimeModel
+from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging
@@ -186,7 +187,7 @@ def __call__(
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

-latents = latents * self.scheduler.init_noise_sigma
+latents = latents * np.float(self.scheduler.init_noise_sigma)

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -197,15 +198,20 @@ def __call__(
if accepts_eta:
extra_step_kwargs["eta"] = eta

+timestep_dtype = next(
+(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+)
+timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
Comment on lines +201 to +204 (@anton-l, Member, Author, Nov 7, 2022):
Passing float timesteps (as in K-LMS) to the ONNX unet will result in a type error. This adds a bit of back-compatibility until we update the checkpoints to accept float timesteps. The ONNX export script will set the type to float by default.
(We can't change the timesteps to float right away for the hub checkpoints; that would break pipelines from older versions of diffusers.)
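A minimal sketch of the dtype probe this adds, using onnxruntime directly (the model path is hypothetical, and the mapping is abridged from ORT_TO_NP_TYPE in onnx_utils.py):

```python
import numpy as np
import onnxruntime as ort

# Abridged from the ORT_TO_NP_TYPE table added in onnx_utils.py.
ORT_TO_NP_TYPE = {"tensor(float)": np.float32, "tensor(int64)": np.int64}

session = ort.InferenceSession("sd_onnx/unet/model.onnx", providers=["CPUExecutionProvider"])
# Read the declared ONNX type of the "timestep" input; older exports declared
# int64 timesteps, newer ones default to float.
timestep_type = next(
    (inp.type for inp in session.get_inputs() if inp.name == "timestep"),
    "tensor(float)",
)
timestep = np.array([999], dtype=ORT_TO_NP_TYPE[timestep_type])
```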


for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+latent_model_input = latent_model_input.cpu().numpy()

# predict the noise residual
-noise_pred = self.unet(
-sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
-)
+timestep = np.array([t], dtype=timestep_dtype)
+noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings)
noise_pred = noise_pred[0]

# perform guidance
@@ -214,7 +220,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
-latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
latents = np.array(latents)

# call the callback, if provided
@@ -235,6 +241,9 @@ def __call__(
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
).pixel_values.astype(image.dtype)

-image, has_nsfw_concepts = self.safety_checker(clip_input=safety_checker_input, images=image)

+# There will throw an error if use safety_checker batchsize>1
+images, has_nsfw_concept = [], []
+for i in range(image.shape[0]):
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
@@ -8,7 +8,7 @@
from transformers import CLIPFeatureExtractor, CLIPTokenizer

from ...configuration_utils import FrozenDict
-from ...onnx_utils import OnnxRuntimeModel
+from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging
@@ -338,14 +338,21 @@ def __call__(
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:].numpy()

+timestep_dtype = next(
+(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+)
+timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+latent_model_input = latent_model_input.cpu().numpy()

# predict the noise residual
+timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet(
-sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
+sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings
)[0]

# perform guidance
Expand All @@ -354,7 +361,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
-latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
latents = latents.numpy()

# call the callback, if provided
@@ -375,7 +382,7 @@ def __call__(
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
).pixel_values.astype(image.dtype)
-# There will throw an error if use safety_checker batchsize>1
+# safety_checker does not support batched inputs yet
images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
image_i, has_nsfw_concept_i = self.safety_checker(
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -8,7 +8,7 @@
from transformers import CLIPFeatureExtractor, CLIPTokenizer

from ...configuration_utils import FrozenDict
-from ...onnx_utils import OnnxRuntimeModel
+from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging
@@ -352,7 +352,7 @@ def __call__(
self.scheduler.set_timesteps(num_inference_steps)

# scale the initial noise by the standard deviation required by the scheduler
-latents = latents * self.scheduler.init_noise_sigma
+latents = latents * np.float(self.scheduler.init_noise_sigma)

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -363,17 +363,23 @@ def __call__(
if accepts_eta:
extra_step_kwargs["eta"] = eta

+timestep_dtype = next(
+(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+)
+timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latnets in the channel dimension
latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
-latent_model_input = latent_model_input.numpy()
+latent_model_input = latent_model_input.cpu().numpy()

# predict the noise residual
+timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet(
-sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
+sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings
)[0]

# perform guidance
@@ -382,7 +388,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
-latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
latents = latents.numpy()

# call the callback, if provided
@@ -403,7 +409,7 @@ def __call__(
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
).pixel_values.astype(image.dtype)
-# There will throw an error if use safety_checker batchsize>1
+# safety_checker does not support batched inputs yet
images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
image_i, has_nsfw_concept_i = self.safety_checker(
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
@@ -67,6 +67,7 @@
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
ONNX_WEIGHTS_NAME = "model.onnx"
+ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"